diff --git a/framereader.go b/framereader.go index 225ca84..6e78654 100644 --- a/framereader.go +++ b/framereader.go @@ -116,3 +116,7 @@ func (dfr *defaultFrameReader) readFrame() (*Frame, error) { return &Frame{RequestID: requestID, Cmd: cmd, Flags: flags, Payload: payload, ctx: dfr.ctx}, nil } + +func (dfr *defaultFrameReader) CloseStream(requestID uint64) { + delete(dfr.streamFrameCh, requestID) +} diff --git a/qrpc.go b/qrpc.go index 9d7813e..cfd47ec 100644 --- a/qrpc.go +++ b/qrpc.go @@ -15,6 +15,8 @@ const ( NBFlag // CancelFlag cancels a stream (TODO) CancelFlag + // ErrorFlag indicate client should close the specific stream + ErrorFlag // CompressFlag indicate packet is compressed (TODO) CompressFlag // PushFlag mean the frame is pushed from server diff --git a/serveconn.go b/serveconn.go index c556df3..7713c39 100644 --- a/serveconn.go +++ b/serveconn.go @@ -18,6 +18,7 @@ type serveconn struct { cancelCtx context.CancelFunc // ctx is the corresponding context for cancelCtx ctx context.Context + wg sync.WaitGroup // wait group for goroutines idx int @@ -104,22 +105,26 @@ func (sc *serveconn) serve(ctx context.Context) { idx := sc.idx defer func() { + // connection level panic if err := recover(); err != nil { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - sc.server.logf("http: panic serving %v: %v\n%s", sc.rwc.RemoteAddr().String(), err, buf) + sc.server.logf("qrpc: panic serving %v: %v\n%s", sc.rwc.RemoteAddr().String(), err, buf) } sc.Close() - cancelCtx() }() binding := sc.server.bindings[idx] sc.reader = newFrameReader(ctx, sc.rwc, binding.DefaultReadTimeout) sc.writer = newFrameWriter(ctx, sc.writeFrameCh) // only used by blocking mode - go sc.readFrames() - go sc.writeFrames(binding.DefaultWriteTimeout) + goFunc(&sc.wg, func() { + sc.readFrames() + }) + goFunc(&sc.wg, func() { + sc.writeFrames(binding.DefaultWriteTimeout) + }) handler := binding.Handler @@ -128,12 +133,20 @@ func (sc *serveconn) serve(ctx context.Context) { case <-ctx.Done(): return case res := <-sc.readFrameCh: + if res.f.Flags&NBFlag == 0 { - handler.ServeQRPC(sc.writer, (*RequestFrame)(res.f)) + func() { + sc.handleRequestPanic(res.f) + handler.ServeQRPC(sc.writer, res.f) + }() res.readMore() } else { res.readMore() - go handler.ServeQRPC(sc.GetWriter(), (*RequestFrame)(res.f)) + goFunc(&sc.wg, func() { + sc.handleRequestPanic(res.f) + handler.ServeQRPC(sc.GetWriter(), res.f) + sc.stopReadStream(res.f.RequestID) + }) } } @@ -141,6 +154,31 @@ func (sc *serveconn) serve(ctx context.Context) { } +func (sc *serveconn) stopReadStream(requestID uint64) { + sc.reader.CloseStream(requestID) +} + +func (sc *serveconn) handleRequestPanic(frame *RequestFrame) { + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.server.logf("qrpc: handleRequestPanic %v: %v\n%s", sc.rwc.RemoteAddr().String(), err, buf) + + sc.stopReadStream(frame.RequestID) + // send error frame + writer := sc.GetWriter() + writer.StartWrite(frame.RequestID, frame.Cmd, ErrorFlag) + err = writer.EndWrite() + if err != nil { + sc.Close() + return + } + + } + +} + // SetID sets id for serveconn func (sc *serveconn) SetID(id string) { if id == "" { @@ -168,7 +206,7 @@ func (sc *serveconn) GetWriter() FrameWriter { var ErrInvalidPacket = errors.New("invalid packet") type readFrameResult struct { - f *Frame // valid until readMore is called + f *RequestFrame // valid until readMore is called // readMore should be called once the consumer no longer needs or // retains f. After readMore, f is invalid and more frames can be @@ -203,7 +241,7 @@ func (sc *serveconn) readFrames() (err error) { return err } select { - case sc.readFrameCh <- readFrameResult{f: req, readMore: gateDone}: + case sc.readFrameCh <- readFrameResult{f: (*RequestFrame)(req), readMore: gateDone}: case <-ctx.Done(): return nil } diff --git a/server.go b/server.go index 1b79455..1c60c11 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package qrpc import ( "context" "errors" - "fmt" "net" "sync" "sync/atomic" @@ -71,9 +70,9 @@ func (mux *ServeMux) Handle(cmd Cmd, handler Handler) { // cmd matches the request. func (mux *ServeMux) ServeQRPC(w FrameWriter, r *RequestFrame) { mux.mu.RLock() - fmt.Println("cmd", r.Cmd) h, ok := mux.m[r.Cmd] if !ok { + // TODO error response return } mux.mu.RUnlock() @@ -117,11 +116,11 @@ func (srv *Server) ListenAndServe() error { return err } - goFunc(&srv.wg, func(ln net.Listener, idx int) func() { + goFunc(&srv.wg, func(idx int) func() { return func() { srv.serve(tcpKeepAliveListener{ln.(*net.TCPListener)}, idx) } - }(ln, idx)) + }(idx)) } @@ -182,11 +181,9 @@ func (srv *Server) serve(l tcpKeepAliveListener, idx int) error { tempDelay = 0 c := srv.newConn(rw, idx) - goFunc(&srv.wg, func(c *serveconn) func() { - return func() { - c.serve(serveCtx) - } - }(c)) + goFunc(&srv.wg, func() { + c.serve(serveCtx) + }) } } diff --git a/test/qrpc_test.go b/test/qrpc_test.go index b7415e5..3637f88 100644 --- a/test/qrpc_test.go +++ b/test/qrpc_test.go @@ -24,15 +24,18 @@ func TestHelloWorld(t *testing.T) { fmt.Println(frame) }) - resp, err := conn.Request(HelloCmd, qrpc.NBFlag, []byte("xu")) - if err != nil { - panic(err) - } - frame := resp.GetFrame() - if frame == nil { - panic("nil frame") + for _, flag := range []qrpc.PacketFlag{0, qrpc.NBFlag} { + resp, err := conn.Request(HelloCmd, flag, []byte("xu")) + if err != nil { + panic(err) + } + frame := resp.GetFrame() + if frame == nil { + panic("nil frame") + } + fmt.Println("resp is ", string(frame.Payload)) } - fmt.Println("resp is ", string(frame.Payload)) + } func TestWriter(t *testing.T) {