From a0f3c3749848079c1aafc4eccb3eefa2c999ec86 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Fri, 11 Oct 2024 14:47:26 +0200 Subject: [PATCH] fix: race condition In client.go if room, ok := message.(outgoing.Room); ok { c.info.RoomID = room.ID } this part isn't thread safe. It could happen that user disconnected but wasn't removed from a room, because the disconnecting go routine couldn't see the roomID yet. --- ws/client.go | 6 --- ws/event_clientanswer.go | 10 ++-- ws/event_clientice.go | 10 ++-- ws/event_connected.go | 2 +- ws/event_create.go | 3 +- ws/event_disconnected.go | 9 ++-- ws/event_hostice.go | 10 ++-- ws/event_hostoffer.go | 10 ++-- ws/event_join.go | 3 +- ws/event_name.go | 14 ++--- ws/event_share.go | 14 ++--- ws/event_stop_share.go | 11 ++-- ws/rooms.go | 23 ++++++-- ws/rooms_test.go | 112 +++++++++++++++++++++++++++++++++++++++ 14 files changed, 163 insertions(+), 74 deletions(-) create mode 100644 ws/rooms_test.go diff --git a/ws/client.go b/ws/client.go index daa46ed3..3d3141b2 100644 --- a/ws/client.go +++ b/ws/client.go @@ -41,7 +41,6 @@ type ClientMessage struct { type ClientInfo struct { ID xid.ID - RoomID string Authenticated bool AuthenticatedUser string Write chan outgoing.Message @@ -60,7 +59,6 @@ func newClient(conn *websocket.Conn, req *http.Request, read chan ClientMessage, Authenticated: authenticated, AuthenticatedUser: authenticatedUser, ID: xid.New(), - RoomID: "", Addr: ip, Write: make(chan outgoing.Message, 1), }, @@ -158,10 +156,6 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) { continue } - if room, ok := message.(outgoing.Room); ok { - c.info.RoomID = room.ID - } - if err := writeJSON(c.conn, typed); err != nil { c.printWebSocketError("write", err) c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error()) diff --git a/ws/event_clientanswer.go b/ws/event_clientanswer.go index eb4ad558..3965aa8d 100644 --- a/ws/event_clientanswer.go +++ b/ws/event_clientanswer.go @@ -16,13 +16,9 @@ func init() { type ClientAnswer outgoing.P2PMessage func (e *ClientAnswer) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } session, ok := room.Sessions[e.SID] diff --git a/ws/event_clientice.go b/ws/event_clientice.go index 9701ee96..4c0ac155 100644 --- a/ws/event_clientice.go +++ b/ws/event_clientice.go @@ -16,13 +16,9 @@ func init() { type ClientICE outgoing.P2PMessage func (e *ClientICE) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } session, ok := room.Sessions[e.SID] diff --git a/ws/event_connected.go b/ws/event_connected.go index 6758f926..1ac385e6 100644 --- a/ws/event_connected.go +++ b/ws/event_connected.go @@ -3,6 +3,6 @@ package ws type Connected struct{} func (e Connected) Execute(rooms *Rooms, current ClientInfo) error { - rooms.connected[current.ID] = true + rooms.connected[current.ID] = "" return nil } diff --git a/ws/event_create.go b/ws/event_create.go index ac9882fe..6f1daf21 100644 --- a/ws/event_create.go +++ b/ws/event_create.go @@ -23,7 +23,7 @@ type Create struct { } func (e *Create) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID != "" { + if rooms.connected[current.ID] != "" { return fmt.Errorf("cannot join room, you are already in one") } @@ -74,6 +74,7 @@ func (e *Create) Execute(rooms *Rooms, current ClientInfo) error { }, }, } + rooms.connected[current.ID] = room.ID rooms.Rooms[e.ID] = room room.notifyInfoChanged() usersJoinedTotal.Inc() diff --git a/ws/event_disconnected.go b/ws/event_disconnected.go index 7505724e..d4220508 100644 --- a/ws/event_disconnected.go +++ b/ws/event_disconnected.go @@ -18,14 +18,15 @@ func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error { } func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) { + roomID := rooms.connected[current.ID] delete(rooms.connected, current.ID) current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason} - if current.RoomID == "" { + if roomID == "" { return } - room, ok := rooms.Rooms[current.RoomID] + room, ok := rooms.Rooms[roomID] if !ok { // room may already be removed return @@ -63,12 +64,12 @@ func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) { delete(rooms.connected, member.ID) member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft} } - rooms.closeRoom(current.RoomID) + rooms.closeRoom(roomID) return } if len(room.Users) == 0 { - rooms.closeRoom(current.RoomID) + rooms.closeRoom(roomID) return } diff --git a/ws/event_hostice.go b/ws/event_hostice.go index 2b5a84e6..cf3b54da 100644 --- a/ws/event_hostice.go +++ b/ws/event_hostice.go @@ -16,13 +16,9 @@ func init() { type HostICE outgoing.P2PMessage func (e *HostICE) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } session, ok := room.Sessions[e.SID] diff --git a/ws/event_hostoffer.go b/ws/event_hostoffer.go index 658a5fed..1bff55e1 100644 --- a/ws/event_hostoffer.go +++ b/ws/event_hostoffer.go @@ -16,13 +16,9 @@ func init() { type HostOffer outgoing.P2PMessage func (e *HostOffer) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } session, ok := room.Sessions[e.SID] diff --git a/ws/event_join.go b/ws/event_join.go index 5930f327..ef75595e 100644 --- a/ws/event_join.go +++ b/ws/event_join.go @@ -16,7 +16,7 @@ type Join struct { } func (e *Join) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID != "" { + if rooms.connected[current.ID] != "" { return fmt.Errorf("cannot join room, you are already in one") } @@ -40,6 +40,7 @@ func (e *Join) Execute(rooms *Rooms, current ClientInfo) error { Addr: current.Addr, Write: current.Write, } + rooms.connected[current.ID] = room.ID room.notifyInfoChanged() usersJoinedTotal.Inc() diff --git a/ws/event_name.go b/ws/event_name.go index f7e94d46..d2ebe2b3 100644 --- a/ws/event_name.go +++ b/ws/event_name.go @@ -1,9 +1,5 @@ package ws -import ( - "fmt" -) - func init() { register("name", func() Event { return &Name{} @@ -15,13 +11,9 @@ type Name struct { } func (e *Name) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } room.Users[current.ID].Name = e.UserName diff --git a/ws/event_share.go b/ws/event_share.go index 50b1c0d6..f76296f8 100644 --- a/ws/event_share.go +++ b/ws/event_share.go @@ -1,9 +1,5 @@ package ws -import ( - "fmt" -) - func init() { register("share", func() Event { return &StartShare{} @@ -13,13 +9,9 @@ func init() { type StartShare struct{} func (e *StartShare) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } room.Users[current.ID].Streaming = true diff --git a/ws/event_stop_share.go b/ws/event_stop_share.go index 81c809b8..9d6504a4 100644 --- a/ws/event_stop_share.go +++ b/ws/event_stop_share.go @@ -2,7 +2,6 @@ package ws import ( "bytes" - "fmt" "github.com/screego/server/ws/outgoing" ) @@ -16,13 +15,9 @@ func init() { type StopShare struct{} func (e *StopShare) Execute(rooms *Rooms, current ClientInfo) error { - if current.RoomID == "" { - return fmt.Errorf("not in a room") - } - - room, ok := rooms.Rooms[current.RoomID] - if !ok { - return fmt.Errorf("room with id %s does not exist", current.RoomID) + room, err := rooms.CurrentRoom(current) + if err != nil { + return err } room.Users[current.ID].Streaming = false diff --git a/ws/rooms.go b/ws/rooms.go index e6a86cf1..b9606537 100644 --- a/ws/rooms.go +++ b/ws/rooms.go @@ -20,7 +20,7 @@ func NewRooms(tServer turn.Server, users *auth.Users, conf config.Config) *Rooms return &Rooms{ Rooms: map[string]*Room{}, Incoming: make(chan ClientMessage), - connected: map[xid.ID]bool{}, + connected: map[xid.ID]string{}, turnServer: tServer, users: users, config: conf, @@ -51,7 +51,23 @@ type Rooms struct { users *auth.Users config config.Config r *rand.Rand - connected map[xid.ID]bool + connected map[xid.ID]string +} + +func (r *Rooms) CurrentRoom(info ClientInfo) (*Room, error) { + roomID, ok := r.connected[info.ID] + if !ok { + return nil, fmt.Errorf("not connected") + } + if roomID == "" { + return nil, fmt.Errorf("not in a room") + } + room, ok := r.Rooms[roomID] + if !ok { + return nil, fmt.Errorf("room with id %s does not exist", roomID) + } + + return room, nil } func (r *Rooms) RandUserName() string { @@ -81,7 +97,8 @@ func (r *Rooms) Upgrade(w http.ResponseWriter, req *http.Request) { func (r *Rooms) Start() { for msg := range r.Incoming { - if !msg.SkipConnectedCheck && !r.connected[msg.Info.ID] { + _, connected := r.connected[msg.Info.ID] + if !msg.SkipConnectedCheck && !connected { log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore") continue } diff --git a/ws/rooms_test.go b/ws/rooms_test.go new file mode 100644 index 00000000..ff5bdbb8 --- /dev/null +++ b/ws/rooms_test.go @@ -0,0 +1,112 @@ +package ws + +import ( + "encoding/json" + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/xid" +) + +const SERVER = "ws://localhost:5050/stream" + +func TestMultipleClients(t *testing.T) { + t.Skip("only for manual testing") + r := rand.New(rand.NewSource(time.Now().UnixMicro())) + + var wg sync.WaitGroup + + for j := 0; j < 100; j++ { + name := fmt.Sprint(1) + + users := r.Intn(5000) + for i := 0; i < users; i++ { + wg.Add(1) + go func() { + defer wg.Done() + testClient(r.Int63(), name) + }() + if i%100 == 0 { + time.Sleep(10 * time.Millisecond) + } + } + time.Sleep(50 * time.Millisecond) + } + + wg.Wait() +} + +func testClient(i int64, room string) { + r := rand.New(rand.NewSource(i)) + conn, _, err := websocket.DefaultDialer.Dial(SERVER, nil) + if err != nil { + panic(err) + } + go func() { + for { + _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + }() + defer conn.Close() + + ops := r.Intn(100) + for i := 0; i < ops; i++ { + m := msg(r, room) + err = conn.WriteMessage(websocket.TextMessage, m) + if err != nil { + fmt.Println("err", err) + } + time.Sleep(30 * time.Millisecond) + } +} + +func msg(r *rand.Rand, room string) []byte { + typed := Typed{} + var e Event + switch r.Intn(8) { + case 0: + typed.Type = "clientanswer" + e = &ClientAnswer{SID: xid.New(), Value: nil} + case 1: + typed.Type = "clientice" + e = &ClientICE{SID: xid.New(), Value: nil} + case 2: + typed.Type = "hostice" + e = &HostICE{SID: xid.New(), Value: nil} + case 3: + typed.Type = "hostoffer" + e = &HostOffer{SID: xid.New(), Value: nil} + case 4: + typed.Type = "name" + e = &Name{UserName: "a"} + case 5: + typed.Type = "share" + e = &StartShare{} + case 6: + typed.Type = "stopshare" + e = &StopShare{} + case 7: + typed.Type = "create" + e = &Create{ID: room, CloseOnOwnerLeave: r.Intn(2) == 0, JoinIfExist: r.Intn(2) == 0, Mode: ConnectionSTUN, UserName: "hello"} + } + + b, err := json.Marshal(e) + if err != nil { + panic(err) + } + typed.Payload = json.RawMessage(b) + + b, err = json.Marshal(typed) + if err != nil { + panic(err) + } + return b +}