diff --git a/framereader.go b/framereader.go index 6e78654..53191e2 100644 --- a/framereader.go +++ b/framereader.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "net" + "sync" ) // defaultFrameReader is responsible for read frames @@ -14,6 +15,8 @@ type defaultFrameReader struct { rbuf [16]byte // for header streamFrameCh map[uint64]chan<- *Frame ctx context.Context + mu sync.Mutex + doneStreams []uint64 } // newFrameReader creates a FrameWriter instance to read frames @@ -30,6 +33,8 @@ var ( func (dfr *defaultFrameReader) ReadFrame() (*Frame, error) { + dfr.closeStreams() + f, err := dfr.readFrame() if err != nil { return f, err @@ -118,5 +123,15 @@ func (dfr *defaultFrameReader) readFrame() (*Frame, error) { } func (dfr *defaultFrameReader) CloseStream(requestID uint64) { - delete(dfr.streamFrameCh, requestID) + dfr.mu.Lock() + dfr.doneStreams = append(dfr.doneStreams, requestID) + dfr.mu.Unlock() +} + +func (dfr *defaultFrameReader) closeStreams() { + dfr.mu.Lock() + for _, requestID := range dfr.doneStreams { + delete(dfr.streamFrameCh, requestID) + } + dfr.mu.Unlock() } diff --git a/qrpc.go b/qrpc.go index cfd47ec..5c323a9 100644 --- a/qrpc.go +++ b/qrpc.go @@ -13,9 +13,9 @@ const ( StreamEndFlag // NBFlag means it should be handled nonblockly NBFlag - // CancelFlag cancels a stream (TODO) + // CancelFlag cancels a stream [client to server] (TODO) CancelFlag - // ErrorFlag indicate client should close the specific stream + // ErrorFlag indicate client should close the specific stream [server to client] (TODO) ErrorFlag // CompressFlag indicate packet is compressed (TODO) CompressFlag diff --git a/serveconn.go b/serveconn.go index 7713c39..6c1413b 100644 --- a/serveconn.go +++ b/serveconn.go @@ -22,10 +22,10 @@ type serveconn struct { idx int - mu sync.Mutex id string // the only mutable - closeCh chan struct{} + untrack uint32 // ony the first call to untrack actually do it, subsequent calls should wait for untrackedCh + untrackedCh chan struct{} // rwc is the underlying network connection. // This is never wrapped by other types and is the value given out @@ -46,36 +46,7 @@ var ConnectionInfoKey = &contextKey{"qrpc-connection"} // ConnectionInfo for store info on connection type ConnectionInfo struct { SC *serveconn // read only - anything interface{} -} - -// Lock reuses the mu of serveconn -func (ci *ConnectionInfo) Lock() { - ci.SC.mu.Lock() -} - -// Unlock reuses the mu of serveconn -func (ci *ConnectionInfo) Unlock() { - ci.SC.mu.Unlock() -} - -// Store sets anything -func (ci *ConnectionInfo) Store(anything interface{}) { - ci.SC.mu.Lock() - ci.anything = anything - ci.SC.mu.Unlock() -} - -// StoreLocked sets anything without lock -func (ci *ConnectionInfo) StoreLocked(anything interface{}) { - ci.anything = anything -} - -// Load gets anything -func (ci *ConnectionInfo) Load(anything interface{}) interface{} { - ci.SC.mu.Lock() - defer ci.SC.mu.Unlock() - return ci.anything + Anything interface{} } // Server returns the server @@ -184,15 +155,11 @@ func (sc *serveconn) SetID(id string) { if id == "" { panic("empty id not allowed") } - sc.mu.Lock() - defer sc.mu.Unlock() sc.id = id sc.server.bindID(sc, id) } func (sc *serveconn) GetID() string { - sc.mu.Lock() - defer sc.mu.Unlock() return sc.id } @@ -274,21 +241,24 @@ func (sc *serveconn) writeFrames(timeout int) (err error) { } // Close the connection. -func (sc *serveconn) Close() (<-chan struct{}, error) { +func (sc *serveconn) Close() error { - return sc.closeLocked(false) + ok, ch := sc.server.untrack(sc) + if !ok { + <-ch + } + return sc.closeUntracked() } -func (sc *serveconn) closeLocked(serverLocked bool) (<-chan struct{}, error) { +func (sc *serveconn) closeUntracked() error { err := sc.rwc.Close() if err != nil { - return sc.closeCh, err + return err } sc.cancelCtx() - sc.server.untrack(sc, serverLocked) - close(sc.closeCh) + close(sc.untrackedCh) - return sc.closeCh, nil + return nil } diff --git a/server.go b/server.go index 1c60c11..e81bf77 100644 --- a/server.go +++ b/server.go @@ -85,10 +85,11 @@ type Server struct { bindings []ServerBinding // manages below - mu sync.Mutex - listeners map[net.Listener]struct{} - doneChan chan struct{} - id2Conn []map[string]*serveconn + mu sync.Mutex + listeners map[net.Listener]struct{} + doneChan chan struct{} + + id2Conn []sync.Map activeConn []sync.Map // for better iterate when write, map[*serveconn]struct{} wg sync.WaitGroup // wait group for goroutines @@ -102,7 +103,7 @@ func NewServer(bindings []ServerBinding) *Server { bindings: bindings, listeners: make(map[net.Listener]struct{}), doneChan: make(chan struct{}), - id2Conn: []map[string]*serveconn{map[string]*serveconn{}, map[string]*serveconn{}}, + id2Conn: make([]sync.Map, len(bindings)), activeConn: make([]sync.Map, len(bindings))} } @@ -219,7 +220,7 @@ func (srv *Server) newConn(rwc net.Conn, idx int) *serveconn { server: srv, rwc: rwc, idx: idx, - closeCh: make(chan struct{}), + untrackedCh: make(chan struct{}), readFrameCh: make(chan readFrameResult), writeFrameCh: make(chan writeFrameRequest)} @@ -231,33 +232,37 @@ func (srv *Server) newConn(rwc net.Conn, idx int) *serveconn { func (srv *Server) bindID(sc *serveconn, id string) { idx := sc.idx - srv.mu.Lock() - defer srv.mu.Unlock() - v, ok := srv.id2Conn[idx][id] + v, ok := srv.id2Conn[idx].Load(id) + vsc := v.(*serveconn) if ok { - if v == sc { + if vsc == sc { return } - ch, _ := v.closeLocked(true) - <-ch + ok, ch := srv.untrack(vsc) + if !ok { + <-ch + } + vsc.closeUntracked() } - srv.id2Conn[idx][id] = sc + srv.id2Conn[idx].Store(id, sc) } -func (srv *Server) untrack(sc *serveconn, inLock bool) { +func (srv *Server) untrack(sc *serveconn) (bool, <-chan struct{}) { - idx := sc.idx - if !inLock { - srv.mu.Lock() - defer srv.mu.Unlock() + locked := atomic.CompareAndSwapUint32(&sc.untrack, 0, 1) + if !locked { + return false, sc.untrackedCh } + idx := sc.idx if sc.id != "" { - delete(srv.id2Conn[idx], sc.id) + srv.id2Conn[idx].Delete(sc.id) } srv.activeConn[idx].Delete(sc) + close(sc.untrackedCh) + return true, sc.untrackedCh } func (srv *Server) logf(format string, args ...interface{}) { diff --git a/test/qrpc_test.go b/test/qrpc_test.go index 3637f88..1931ebb 100644 --- a/test/qrpc_test.go +++ b/test/qrpc_test.go @@ -74,7 +74,6 @@ func startServer() { // time.Sleep(time.Hour) writer.StartWrite(request.RequestID, HelloRespCmd, 0) - fmt.Println("in server handle") writer.WriteBytes(append([]byte("hello world "), request.Payload...)) err := writer.EndWrite() if err != nil {