Skip to content

Commit

Permalink
reduce locks
Browse files Browse the repository at this point in the history
  • Loading branch information
徐志强 committed Jul 23, 2018
1 parent 85c3a7d commit 4c4dc4e
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 66 deletions.
17 changes: 16 additions & 1 deletion framereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/binary"
"errors"
"net"
"sync"
)

// defaultFrameReader is responsible for read frames
Expand All @@ -14,6 +15,8 @@ type defaultFrameReader struct {
rbuf [16]byte // for header
streamFrameCh map[uint64]chan<- *Frame
ctx context.Context
mu sync.Mutex
doneStreams []uint64
}

// newFrameReader creates a FrameWriter instance to read frames
Expand All @@ -30,6 +33,8 @@ var (

func (dfr *defaultFrameReader) ReadFrame() (*Frame, error) {

dfr.closeStreams()

f, err := dfr.readFrame()
if err != nil {
return f, err
Expand Down Expand Up @@ -118,5 +123,15 @@ func (dfr *defaultFrameReader) readFrame() (*Frame, error) {
}

func (dfr *defaultFrameReader) CloseStream(requestID uint64) {
delete(dfr.streamFrameCh, requestID)
dfr.mu.Lock()
dfr.doneStreams = append(dfr.doneStreams, requestID)
dfr.mu.Unlock()
}

func (dfr *defaultFrameReader) closeStreams() {
dfr.mu.Lock()
for _, requestID := range dfr.doneStreams {
delete(dfr.streamFrameCh, requestID)
}
dfr.mu.Unlock()
}
4 changes: 2 additions & 2 deletions qrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ const (
StreamEndFlag
// NBFlag means it should be handled nonblockly
NBFlag
// CancelFlag cancels a stream (TODO)
// CancelFlag cancels a stream [client to server] (TODO)
CancelFlag
// ErrorFlag indicate client should close the specific stream
// ErrorFlag indicate client should close the specific stream [server to client] (TODO)
ErrorFlag
// CompressFlag indicate packet is compressed (TODO)
CompressFlag
Expand Down
56 changes: 13 additions & 43 deletions serveconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ type serveconn struct {

idx int

mu sync.Mutex
id string // the only mutable

closeCh chan struct{}
untrack uint32 // ony the first call to untrack actually do it, subsequent calls should wait for untrackedCh
untrackedCh chan struct{}

// rwc is the underlying network connection.
// This is never wrapped by other types and is the value given out
Expand All @@ -46,36 +46,7 @@ var ConnectionInfoKey = &contextKey{"qrpc-connection"}
// ConnectionInfo for store info on connection
type ConnectionInfo struct {
SC *serveconn // read only
anything interface{}
}

// Lock reuses the mu of serveconn
func (ci *ConnectionInfo) Lock() {
ci.SC.mu.Lock()
}

// Unlock reuses the mu of serveconn
func (ci *ConnectionInfo) Unlock() {
ci.SC.mu.Unlock()
}

// Store sets anything
func (ci *ConnectionInfo) Store(anything interface{}) {
ci.SC.mu.Lock()
ci.anything = anything
ci.SC.mu.Unlock()
}

// StoreLocked sets anything without lock
func (ci *ConnectionInfo) StoreLocked(anything interface{}) {
ci.anything = anything
}

// Load gets anything
func (ci *ConnectionInfo) Load(anything interface{}) interface{} {
ci.SC.mu.Lock()
defer ci.SC.mu.Unlock()
return ci.anything
Anything interface{}
}

// Server returns the server
Expand Down Expand Up @@ -184,15 +155,11 @@ func (sc *serveconn) SetID(id string) {
if id == "" {
panic("empty id not allowed")
}
sc.mu.Lock()
defer sc.mu.Unlock()
sc.id = id
sc.server.bindID(sc, id)
}

func (sc *serveconn) GetID() string {
sc.mu.Lock()
defer sc.mu.Unlock()
return sc.id
}

Expand Down Expand Up @@ -274,21 +241,24 @@ func (sc *serveconn) writeFrames(timeout int) (err error) {
}

// Close the connection.
func (sc *serveconn) Close() (<-chan struct{}, error) {
func (sc *serveconn) Close() error {

return sc.closeLocked(false)
ok, ch := sc.server.untrack(sc)
if !ok {
<-ch
}
return sc.closeUntracked()

}

func (sc *serveconn) closeLocked(serverLocked bool) (<-chan struct{}, error) {
func (sc *serveconn) closeUntracked() error {
err := sc.rwc.Close()
if err != nil {
return sc.closeCh, err
return err
}
sc.cancelCtx()

sc.server.untrack(sc, serverLocked)
close(sc.closeCh)
close(sc.untrackedCh)

return sc.closeCh, nil
return nil
}
43 changes: 24 additions & 19 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ type Server struct {
bindings []ServerBinding

// manages below
mu sync.Mutex
listeners map[net.Listener]struct{}
doneChan chan struct{}
id2Conn []map[string]*serveconn
mu sync.Mutex
listeners map[net.Listener]struct{}
doneChan chan struct{}

id2Conn []sync.Map
activeConn []sync.Map // for better iterate when write, map[*serveconn]struct{}

wg sync.WaitGroup // wait group for goroutines
Expand All @@ -102,7 +103,7 @@ func NewServer(bindings []ServerBinding) *Server {
bindings: bindings,
listeners: make(map[net.Listener]struct{}),
doneChan: make(chan struct{}),
id2Conn: []map[string]*serveconn{map[string]*serveconn{}, map[string]*serveconn{}},
id2Conn: make([]sync.Map, len(bindings)),
activeConn: make([]sync.Map, len(bindings))}
}

Expand Down Expand Up @@ -219,7 +220,7 @@ func (srv *Server) newConn(rwc net.Conn, idx int) *serveconn {
server: srv,
rwc: rwc,
idx: idx,
closeCh: make(chan struct{}),
untrackedCh: make(chan struct{}),
readFrameCh: make(chan readFrameResult),
writeFrameCh: make(chan writeFrameRequest)}

Expand All @@ -231,33 +232,37 @@ func (srv *Server) newConn(rwc net.Conn, idx int) *serveconn {
func (srv *Server) bindID(sc *serveconn, id string) {

idx := sc.idx
srv.mu.Lock()
defer srv.mu.Unlock()
v, ok := srv.id2Conn[idx][id]
v, ok := srv.id2Conn[idx].Load(id)
vsc := v.(*serveconn)
if ok {
if v == sc {
if vsc == sc {
return
}
ch, _ := v.closeLocked(true)
<-ch
ok, ch := srv.untrack(vsc)
if !ok {
<-ch
}
vsc.closeUntracked()
}

srv.id2Conn[idx][id] = sc
srv.id2Conn[idx].Store(id, sc)
}

func (srv *Server) untrack(sc *serveconn, inLock bool) {
func (srv *Server) untrack(sc *serveconn) (bool, <-chan struct{}) {

idx := sc.idx
if !inLock {
srv.mu.Lock()
defer srv.mu.Unlock()
locked := atomic.CompareAndSwapUint32(&sc.untrack, 0, 1)
if !locked {
return false, sc.untrackedCh
}
idx := sc.idx

if sc.id != "" {
delete(srv.id2Conn[idx], sc.id)
srv.id2Conn[idx].Delete(sc.id)
}
srv.activeConn[idx].Delete(sc)

close(sc.untrackedCh)
return true, sc.untrackedCh
}

func (srv *Server) logf(format string, args ...interface{}) {
Expand Down
1 change: 0 additions & 1 deletion test/qrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func startServer() {
// time.Sleep(time.Hour)
writer.StartWrite(request.RequestID, HelloRespCmd, 0)

fmt.Println("in server handle")
writer.WriteBytes(append([]byte("hello world "), request.Payload...))
err := writer.EndWrite()
if err != nil {
Expand Down

0 comments on commit 4c4dc4e

Please sign in to comment.