Skip to content

Commit

Permalink
add stream and refactor StreamFlag
Browse files Browse the repository at this point in the history
  • Loading branch information
徐志强 committed Jul 26, 2018
1 parent dff3aa0 commit 234b289
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 137 deletions.
54 changes: 42 additions & 12 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type Connection struct {

mu sync.Mutex
respes map[uint64]*response

cs *connstreams
}

// Response for response frames
Expand Down Expand Up @@ -82,7 +84,8 @@ func newConnectionWithPool(addr string, conf ConnectionConfig, p *sync.Pool, f S

c := &Connection{
Conn: conn, conf: conf, subscriber: f, p: p,
writeFrameCh: make(chan writeFrameRequest), respes: make(map[uint64]*response)}
writeFrameCh: make(chan writeFrameRequest), respes: make(map[uint64]*response),
cs: newConnStreams()}

if p == nil {
c.wakeup()
Expand Down Expand Up @@ -124,6 +127,7 @@ func (conn *Connection) GetWriter() FrameWriter {

// StreamWriter is returned by StreamRequest
type StreamWriter interface {
RequestID() uint64
StartWrite(cmd Cmd)
WriteBytes(v []byte) // v is copied in WriteBytes
EndWrite(end bool) error // block until scheduled
Expand All @@ -132,26 +136,30 @@ type StreamWriter interface {
type defaultStreamWriter struct {
w *defaultFrameWriter
requestID uint64
flags PacketFlag
flags FrameFlag
}

// NewStreamWriter creates a streamwriter from StreamWriter
func NewStreamWriter(w FrameWriter, requestID uint64, flags PacketFlag) StreamWriter {
func NewStreamWriter(w FrameWriter, requestID uint64, flags FrameFlag) StreamWriter {
dfr, ok := w.(*defaultFrameWriter)
if !ok {
return nil
}
return newStreamWriter(dfr, requestID, flags)
}

func newStreamWriter(w *defaultFrameWriter, requestID uint64, flags PacketFlag) StreamWriter {
func newStreamWriter(w *defaultFrameWriter, requestID uint64, flags FrameFlag) StreamWriter {
return &defaultStreamWriter{w: w, requestID: requestID, flags: flags}
}

func (dsw *defaultStreamWriter) StartWrite(cmd Cmd) {
dsw.w.StartWrite(dsw.requestID, cmd, dsw.flags)
}

func (dsw *defaultStreamWriter) RequestID() uint64 {
return dsw.requestID
}

func (dsw *defaultStreamWriter) WriteBytes(v []byte) {
dsw.w.WriteBytes(v)
}
Expand All @@ -161,20 +169,28 @@ func (dsw *defaultStreamWriter) EndWrite(end bool) error {
}

// StreamRequest is for streamed request
func (conn *Connection) StreamRequest(cmd Cmd, flags PacketFlag, payload []byte) (Response, StreamWriter, error) {
func (conn *Connection) StreamRequest(cmd Cmd, flags FrameFlag, payload []byte) (Response, StreamWriter, error) {

flags |= StreamFlag | NBFlag
flags = flags.ToStream()
requestID, resp, writer, err := conn.writeFirstFrame(cmd, flags, payload)
if err != nil {
return nil, nil, err
}
return resp, newStreamWriter(writer, requestID, flags), nil
}

// Request send a request frame and returns response frame
// ResetStream resets a stream
func (conn *Connection) ResetStream(requestID uint64) error {
writer := conn.GetWriter()
writer.StartWrite(requestID, 0, StreamRstFlag)
return writer.EndWrite()
}

// Request send a nonstreamed request frame and returns response frame
// error is non nil when write failed
func (conn *Connection) Request(cmd Cmd, flags PacketFlag, payload []byte) (Response, error) {
func (conn *Connection) Request(cmd Cmd, flags FrameFlag, payload []byte) (Response, error) {

flags = flags.ToNonStream()
_, resp, _, err := conn.writeFirstFrame(cmd, flags, payload)

return resp, err
Expand All @@ -185,7 +201,7 @@ var (
ErrNoNewUUID = errors.New("no new uuid available temporary")
)

func (conn *Connection) writeFirstFrame(cmd Cmd, flags PacketFlag, payload []byte) (uint64, Response, *defaultFrameWriter, error) {
func (conn *Connection) writeFirstFrame(cmd Cmd, flags FrameFlag, payload []byte) (uint64, Response, *defaultFrameWriter, error) {
var (
requestID uint64
suc bool
Expand Down Expand Up @@ -261,6 +277,7 @@ func (conn *Connection) Close(err error) error {
}

conn.cancelCtx()
conn.cs.Wait()

var fatal bool
if !(err == context.Canceled || err == context.DeadlineExceeded) {
Expand All @@ -287,12 +304,12 @@ func (conn *Connection) readFrames() {
conn.Close(err)
}()
for {
frame, err = conn.reader.ReadFrame()
frame, err = conn.reader.ReadFrame(conn.cs)
if err != nil {
return
}

if frame.Flags&PushFlag != 0 {
if frame.Flags.IsPush() {
// pushed frame
if conn.subscriber != nil {
conn.subscriber(conn, frame)
Expand Down Expand Up @@ -326,7 +343,20 @@ func (conn *Connection) writeFrames() (err error) {
for {
select {
case res := <-conn.writeFrameCh:
_, err := writer.Write(res.frame)
dfw := res.dfw
flags := dfw.Flags()
requestID := dfw.RequestID()

// skip stream logic if PushFlag set
if !flags.IsPush() {
s := conn.cs.CreateOrGetStream(conn.ctx, requestID, flags)
if !s.AddOutFrame(requestID, flags) {
res.result <- ErrWriteAfterCloseSelf
break
}
}

_, err := writer.Write(dfw.GetWbuf())
res.result <- err
if err != nil {
return err
Expand Down
10 changes: 5 additions & 5 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
// all fields are readly only
type Frame struct {
RequestID uint64
Flags PacketFlag
Flags FrameFlag
Cmd Cmd
Payload []byte
frameCh chan *Frame // non nil for the first frame in stream
Stream *stream // non nil for the first frame in stream

// ctx is either the client or server context. It should only
// be modified via copying the whole Request using WithContext.
Expand All @@ -22,7 +22,7 @@ type Frame struct {

// FrameCh get the next frame ch
func (r *Frame) FrameCh() <-chan *Frame {
return r.frameCh
return r.Stream.frameCh
}

// Context returns the request's context. To change the context, use
Expand All @@ -34,8 +34,8 @@ func (r *Frame) FrameCh() <-chan *Frame {
// For outgoing client requests, the context controls cancelation.
//
// For incoming server requests, the context is canceled when the
// client's connection closes, the request is canceled (with HTTP/2),
// or when the ServeHTTP method returns.
// client's connection closes, the request is canceled ,
// or when the ServeQRPC method returns. (TODO)
func (r *Frame) Context() context.Context {
if r.ctx != nil {
return r.ctx
Expand Down
95 changes: 12 additions & 83 deletions framereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,14 @@ import (
"encoding/binary"
"errors"
"net"
"sync"
)

// defaultFrameReader is responsible for read frames
// should create one instance per connection
type defaultFrameReader struct {
*Reader
rbuf [16]byte // for header
streamFrameCh map[uint64]chan *Frame
ctx context.Context
mu sync.Mutex
doneStreams []uint64
rbuf [16]byte // for header
ctx context.Context
}

// newFrameReader creates a FrameWriter instance to read frames
Expand All @@ -27,83 +23,35 @@ func newFrameReader(ctx context.Context, rwc net.Conn, timeout int) *defaultFram
var (
// ErrInvalidFrameSize when invalid size
ErrInvalidFrameSize = errors.New("invalid frame size")
// ErrStreamFrameMustNB when stream frame not non block
ErrStreamFrameMustNB = errors.New("streaming frame must be non block")
)

func (dfr *defaultFrameReader) ReadFrame() (*Frame, error) {

dfr.closeStreams()
func (dfr *defaultFrameReader) ReadFrame(cs *connstreams) (*Frame, error) {

f, err := dfr.readFrame()
if err != nil {
dfr.shutdown()
return f, err
}

requestID := f.RequestID
flags := f.Flags

// done for non streamed frame
if flags&StreamFlag == 0 {
return f, nil
}

// deal with streamed frames
// ReadFrame is not threadsafe, so below need not be atomic

if flags&NBFlag == 0 {
return nil, ErrStreamFrameMustNB
}
for {
s := cs.CreateOrGetStream(dfr.ctx, requestID, f.Flags)

if flags&StreamEndFlag == 0 {
if dfr.streamFrameCh == nil {
// the first stream for the connection
dfr.streamFrameCh = make(map[uint64]chan *Frame)
}
ch, ok := dfr.streamFrameCh[requestID]
if !ok {
// the first frame for the stream
ch = make(chan *Frame)
dfr.streamFrameCh[requestID] = ch
f.frameCh = ch
if s.TryBind(f) {
return f, nil
}

// continuation frame for the stream
select {
case ch <- f:
return dfr.ReadFrame()
case <-dfr.ctx.Done():
return nil, dfr.ctx.Err()
}
} else {
// the ending frame of the stream
if dfr.streamFrameCh == nil {
// ending frame with no prior stream frames
return f, nil
}
ch, ok := dfr.streamFrameCh[requestID]
ok := s.AddInFrame(f)
if !ok {
// ending frame with no prior stream frames
return f, nil
}
// ending frame for the stream
select {
case ch <- f:
close(ch)
delete(dfr.streamFrameCh, requestID)
return dfr.ReadFrame()
case <-dfr.ctx.Done():
return nil, dfr.ctx.Err()
<-s.Done()
cs.DeleteStream(s, flags&PushFlag != 0)
}
}
}

func (dfr *defaultFrameReader) shutdown() {
for _, ch := range dfr.streamFrameCh {
close(ch)
}
}

func (dfr *defaultFrameReader) readFrame() (*Frame, error) {

header := dfr.rbuf[:]
Expand All @@ -116,7 +64,7 @@ func (dfr *defaultFrameReader) readFrame() (*Frame, error) {
requestID := binary.BigEndian.Uint64(header[4:])
cmdAndFlags := binary.BigEndian.Uint32(header[12:])
cmd := Cmd(cmdAndFlags & 0xffffff)
flags := PacketFlag(cmdAndFlags >> 24)
flags := FrameFlag(cmdAndFlags >> 24)
if size < 12 {
return nil, ErrInvalidFrameSize
}
Expand All @@ -129,22 +77,3 @@ func (dfr *defaultFrameReader) readFrame() (*Frame, error) {

return &Frame{RequestID: requestID, Cmd: cmd, Flags: flags, Payload: payload, ctx: dfr.ctx}, nil
}

func (dfr *defaultFrameReader) CloseStream(requestID uint64) {
dfr.mu.Lock()
dfr.doneStreams = append(dfr.doneStreams, requestID)
dfr.mu.Unlock()
}

func (dfr *defaultFrameReader) closeStreams() {
dfr.mu.Lock()
for _, requestID := range dfr.doneStreams {
ch, ok := dfr.streamFrameCh[requestID]
if !ok {
continue
}
close(ch)
delete(dfr.streamFrameCh, requestID)
}
dfr.mu.Unlock()
}
33 changes: 26 additions & 7 deletions framewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
// defaultFrameWriter is responsible for write frames
// should create one instance per goroutine
type defaultFrameWriter struct {
writeCh chan<- writeFrameRequest
wbuf []byte
ctx context.Context
writeCh chan<- writeFrameRequest
wbuf []byte
requestID uint64
cmd Cmd
flags FrameFlag
ctx context.Context
}

// newFrameWriter creates a FrameWriter instance to write frames
Expand All @@ -18,9 +21,11 @@ func newFrameWriter(ctx context.Context, writeCh chan<- writeFrameRequest) *defa
}

// StartWrite Write the FrameHeader.
func (dfw *defaultFrameWriter) StartWrite(requestID uint64, cmd Cmd, flags PacketFlag) {
func (dfw *defaultFrameWriter) StartWrite(requestID uint64, cmd Cmd, flags FrameFlag) {

// Write the FrameHeader.
dfw.requestID = requestID
dfw.cmd = cmd
dfw.flags = flags
dfw.wbuf = append(dfw.wbuf[:0],
0, // 4 bytes of length, filled in in endWrite
0,
Expand All @@ -38,18 +43,32 @@ func (dfw *defaultFrameWriter) StartWrite(requestID uint64, cmd Cmd, flags Packe
byte(cmd>>16),
byte(cmd>>8),
byte(cmd))

}

func (dfw *defaultFrameWriter) RequestID() uint64 {
return dfw.requestID
}

func (dfw *defaultFrameWriter) Flags() FrameFlag {
return dfw.flags
}

func (dfw *defaultFrameWriter) GetWbuf() []byte {
return dfw.wbuf
}

// EndWrite finishes write frame
func (dfw *defaultFrameWriter) EndWrite() error {

length := len(dfw.wbuf) - 4
_ = append(dfw.wbuf[:0],
byte(length>>24),
byte(length>>16),
byte(length>>8),
byte(length))

wfr := writeFrameRequest{frame: dfw.wbuf, result: make(chan error)}
wfr := writeFrameRequest{dfw: dfw, result: make(chan error)}
select {
case dfw.writeCh <- wfr:
case <-dfw.ctx.Done():
Expand All @@ -66,7 +85,7 @@ func (dfw *defaultFrameWriter) EndWrite() error {

func (dfw *defaultFrameWriter) StreamEndWrite(end bool) error {
if end {
dfw.wbuf[12] |= byte(StreamEndFlag)
dfw.flags = dfw.flags.ToEndStream()
}
return dfw.EndWrite()
}
Expand Down
Loading

0 comments on commit 234b289

Please sign in to comment.