From f28ad8639045f200efc98a9d7039c17a204ab9c2 Mon Sep 17 00:00:00 2001 From: Ox Cart Date: Tue, 29 Nov 2016 07:28:40 -0600 Subject: [PATCH] Fixed deadlock on concurrent closing --- session.go | 3 ++- session_test.go | 25 +++++++++++++++++++++++++ stream.go | 3 ++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index ca6e651..5c759c0 100644 --- a/session.go +++ b/session.go @@ -121,13 +121,14 @@ func (s *Session) AcceptStream() (*Stream, error) { // Close is used to close the session and all streams. func (s *Session) Close() (err error) { s.dieLock.Lock() - defer s.dieLock.Unlock() select { case <-s.die: + s.dieLock.Unlock() return errors.New(errBrokenPipe) default: close(s.die) + s.dieLock.Unlock() s.streamLock.Lock() for k := range s.streams { s.streams[k].sessionClose() diff --git a/session_test.go b/session_test.go index d159be7..a9538c4 100644 --- a/session_test.go +++ b/session_test.go @@ -180,6 +180,31 @@ func TestStreamDoubleClose(t *testing.T) { session.Close() } +func TestConcurrentClose(t *testing.T) { + cli, err := net.Dial("tcp", "127.0.0.1:19999") + if err != nil { + t.Fatal(err) + } + session, _ := Client(cli, nil) + numStreams := 100 + streams := make([]*Stream, 0, numStreams) + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < 100; i++ { + stream, _ := session.OpenStream() + streams = append(streams, stream) + } + for _, s := range streams { + stream := s + go func() { + stream.Close() + wg.Done() + }() + } + session.Close() + wg.Wait() +} + func TestTinyReadBuffer(t *testing.T) { cli, err := net.Dial("tcp", "127.0.0.1:19999") if err != nil { diff --git a/stream.go b/stream.go index 0b9f838..34e4abb 100644 --- a/stream.go +++ b/stream.go @@ -126,13 +126,14 @@ func (s *Stream) Write(b []byte) (n int, err error) { // Close implements io.ReadWriteCloser func (s *Stream) Close() error { s.dieLock.Lock() - defer s.dieLock.Unlock() select { case <-s.die: + s.dieLock.Unlock() return errors.New(errBrokenPipe) default: close(s.die) + s.dieLock.Unlock() s.sess.streamClosed(s.id) _, err := s.sess.writeFrame(newFrame(cmdRST, s.id)) return err