diff --git a/session.go b/session.go index d0c3a13..9de309e 100644 --- a/session.go +++ b/session.go @@ -80,7 +80,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { } // OpenStream is used to create a new stream -func (s *Session) OpenStream() (*Stream, error) { +func (s *Session) OpenStream(metadata ...byte) (*Stream, error) { if s.IsClosed() { return nil, errors.New(errBrokenPipe) } @@ -101,9 +101,11 @@ func (s *Session) OpenStream() (*Stream, error) { } s.nextStreamIDLock.Unlock() - stream := newStream(sid, s.config.MaxFrameSize, s) + stream := newStream(sid, metadata, s.config.MaxFrameSize, s) - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + frame := newFrame(cmdSYN, sid) + frame.data = metadata + if _, err := s.writeFrame(frame); err != nil { return nil, errors.Wrap(err, "writeFrame") } @@ -247,7 +249,7 @@ func (s *Session) recvLoop() { case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[f.sid]; !ok { - stream := newStream(f.sid, s.config.MaxFrameSize, s) + stream := newStream(f.sid, append([]byte(nil), f.data...), s.config.MaxFrameSize, s) s.streams[f.sid] = stream select { case s.chAccepts <- stream: diff --git a/session_test.go b/session_test.go index 32fd20b..e754201 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,7 @@ package smux import ( + "bytes" crand "crypto/rand" "encoding/binary" "fmt" @@ -16,7 +17,7 @@ import ( // setupServer starts new server listening on a random localhost port and // returns address of the server, function to stop the server, new client // connection to this server or an error. -func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { +func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { return "", nil, nil, err @@ -27,7 +28,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, tb.Error(err) return } - go handleConnection(conn) + go handleConnection(tb, conn, metadata...) }() addr = ln.Addr().String() conn, err := net.Dial("tcp", addr) @@ -38,10 +39,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, return ln.Addr().String(), func() { ln.Close() }, conn, nil } -func handleConnection(conn net.Conn) { +func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) { session, _ := Server(conn, nil) for { if stream, err := session.AcceptStream(); err == nil { + if !bytes.Equal(metadata, stream.Metadata()) { + tb.Fatal("metadata mimatch") + } go func(s io.ReadWriteCloser) { buf := make([]byte, 65536) for { @@ -58,6 +62,18 @@ func handleConnection(conn net.Conn) { } } +func TestMetadata(t *testing.T) { + metadata := []byte("hello, world") + _, stop, cli, err := setupServer(t, metadata...) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + session.OpenStream(metadata...) + session.Close() +} + func TestEcho(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { diff --git a/stream.go b/stream.go index 2a2b82f..d3b3926 100644 --- a/stream.go +++ b/stream.go @@ -14,6 +14,7 @@ import ( // Stream implements net.Conn type Stream struct { id uint32 + metadata []byte rstflag int32 sess *Session buffer bytes.Buffer @@ -27,9 +28,10 @@ type Stream struct { } // newStream initiates a Stream struct -func newStream(id uint32, frameSize int, sess *Session) *Stream { +func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream { s := new(Stream) s.id = id + s.metadata = metadata s.chReadEvent = make(chan struct{}, 1) s.frameSize = frameSize s.sess = sess @@ -42,6 +44,11 @@ func (s *Stream) ID() uint32 { return s.id } +// Metadata returns stream metadata which was provided when opening stream. +func (s *Stream) Metadata() []byte { + return s.metadata +} + // Read implements net.Conn func (s *Stream) Read(b []byte) (n int, err error) { if len(b) == 0 {