diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index cd28ec3203..8cc24d89a6 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -21,7 +21,10 @@ import ( var _ tpt.CapableConn = &connection{} -const maxAcceptQueueLen = 256 +const ( + maxAcceptQueueLen = 256 + maxInvalidDataChannelClosures = 10 +) type errConnectionTimeout struct{} @@ -51,9 +54,10 @@ type connection struct { remoteKey ic.PubKey remoteMultiaddr ma.Multiaddr - m sync.Mutex - streams map[uint16]*stream - nextStreamID atomic.Int32 + m sync.Mutex + streams map[uint16]*stream + nextStreamID atomic.Int32 + invalidDataChannelClosures atomic.Int32 acceptQueue chan dataChannel @@ -158,7 +162,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error dc.Close() return nil, fmt.Errorf("detach channel failed for stream(%d): %w", streamID, err) } - str := newStream(dc, rwc, func() { c.removeStream(streamID) }) + str := newStream(dc, rwc, maxRTT, func() { c.removeStream(streamID) }, c.onDataChannelClose) if err := c.addStream(str); err != nil { str.Reset() return nil, fmt.Errorf("failed to add stream(%d) to connection: %w", streamID, err) @@ -171,7 +175,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) { case <-c.ctx.Done(): return nil, c.closeErr case dc := <-c.acceptQueue: - str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) }) + str := newStream(dc.channel, dc.stream, maxRTT, func() { c.removeStream(*dc.channel.ID()) }, c.onDataChannelClose) if err := c.addStream(str); err != nil { str.Reset() return nil, err @@ -207,6 +211,17 @@ func (c *connection) removeStream(id uint16) { delete(c.streams, id) } +func (c *connection) onDataChannelClose(remoteClosed bool) { + if !remoteClosed { + if c.invalidDataChannelClosures.Add(1) > maxInvalidDataChannelClosures { + c.closeOnce.Do(func() { + log.Error("closing connection as peer is not closing datachannels: ", c.RemotePeer(), c.RemoteMultiaddr()) + c.closeWithError(errors.New("peer is not closing datachannels")) + }) + } + } +} + func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { c.closeOnce.Do(func() { diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index af2991ed83..ff6265e60d 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -257,7 +257,7 @@ func (l *listener) setupConnection( if err != nil { return nil, err } - handshakeChannel := newStream(w.HandshakeDataChannel, rwc, func() {}) + handshakeChannel := newStream(w.HandshakeDataChannel, rwc, maxRTT, nil, nil) // we do not yet know A's peer ID so accept any inbound remotePubKey, err := l.transport.noiseHandshake(ctx, w.PeerConnection, handshakeChannel, "", crypto.SHA256, true) if err != nil { diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 45641321dc..8d21469ce3 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -2,6 +2,7 @@ package libp2pwebrtc import ( "errors" + "io" "os" "sync" "time" @@ -35,7 +36,6 @@ const ( // add messages to the send buffer once there is space for 1 full // sized message. bufferedAmountLowThreshold = maxBufferedAmount / 2 - // Proto overhead assumption is 5 bytes protoOverhead = 5 // Varint overhead is assumed to be 2 bytes. This is safe since @@ -45,9 +45,9 @@ const ( // is less than or equal to 2 ^ 14, the varint will not be more than // 2 bytes in length. varintOverhead = 2 - // maxFINACKWait is the maximum amount of time a stream will wait to read - // FIN_ACK before closing the data channel - maxFINACKWait = 10 * time.Second + // maxRTT is an estimate of maximum RTT + // We use this to wait for FIN_ACK and Data Channel Close messages from the peer + maxRTT = 10 * time.Second ) type receiveState uint8 @@ -89,18 +89,15 @@ type stream struct { writeDeadline time.Time controlMessageReaderOnce sync.Once - // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control - // message reader. We cannot rely on SetReadDeadline to do this since that is prone to - // race condition where a previous deadline timer fires after the latest call to - // SetReadDeadline - // See: https://github.com/pion/sctp/pull/290 - controlMessageReaderEndTime time.Time - controlMessageReaderDone sync.WaitGroup - - onDone func() + + onCloseOnce sync.Once + onClose func() + onDataChannelClose func(remoteClosed bool) id uint16 // for logging purposes dataChannel *datachannel.DataChannel closeForShutdownErr error + isClosed bool + rtt time.Duration } var _ network.MuxedStream = &stream{} @@ -108,18 +105,20 @@ var _ network.MuxedStream = &stream{} func newStream( channel *webrtc.DataChannel, rwc datachannel.ReadWriteCloser, - onDone func(), + rtt time.Duration, + onClose func(), + onDataChannelClose func(remoteClosed bool), ) *stream { s := &stream{ - reader: pbio.NewDelimitedReader(rwc, maxMessageSize), - writer: pbio.NewDelimitedWriter(rwc), - writeStateChanged: make(chan struct{}, 1), - id: *channel.ID(), - dataChannel: rwc.(*datachannel.DataChannel), - onDone: onDone, + reader: pbio.NewDelimitedReader(rwc, maxMessageSize), + writer: pbio.NewDelimitedWriter(rwc), + writeStateChanged: make(chan struct{}, 1), + id: *channel.ID(), + dataChannel: rwc.(*datachannel.DataChannel), + onClose: onClose, + onDataChannelClose: onDataChannelClose, + rtt: rtt, } - // released when the controlMessageReader goroutine exits - s.controlMessageReaderDone.Add(1) s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) s.dataChannel.OnBufferedAmountLow(func() { s.notifyWriteStateChanged() @@ -129,55 +128,46 @@ func newStream( } func (s *stream) Close() error { + defer s.signalClose() s.mx.Lock() - isClosed := s.closeForShutdownErr != nil - s.mx.Unlock() - if isClosed { + if s.closeForShutdownErr != nil || s.isClosed { + s.mx.Unlock() return nil } + s.isClosed = true + closeWriteErr := s.closeWriteUnlocked() + closeReadErr := s.closeReadUnlocked() + s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) + s.mx.Unlock() - closeWriteErr := s.CloseWrite() - closeReadErr := s.CloseRead() if closeWriteErr != nil || closeReadErr != nil { s.Reset() return errors.Join(closeWriteErr, closeReadErr) } - - s.mx.Lock() - if s.controlMessageReaderEndTime.IsZero() { - s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait) - s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) - go func() { - s.controlMessageReaderDone.Wait() - s.cleanup() - }() - } - s.mx.Unlock() return nil } func (s *stream) Reset() error { + defer s.signalClose() s.mx.Lock() - isClosed := s.closeForShutdownErr != nil - s.mx.Unlock() - if isClosed { + defer s.mx.Unlock() + if s.closeForShutdownErr != nil { return nil } - - defer s.cleanup() - cancelWriteErr := s.cancelWrite() - closeReadErr := s.CloseRead() + // reset even if it's closed already + s.isClosed = true + cancelWriteErr := s.cancelWriteUnlocked() + closeReadErr := s.closeReadUnlocked() s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) - return errors.Join(closeReadErr, cancelWriteErr) + return errors.Join(cancelWriteErr, closeReadErr) } func (s *stream) closeForShutdown(closeErr error) { - defer s.cleanup() - + defer s.signalClose() s.mx.Lock() defer s.mx.Unlock() - s.closeForShutdownErr = closeErr + s.isClosed = true s.notifyWriteStateChanged() } @@ -223,22 +213,14 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { } // spawnControlMessageReader is used for processing control messages after the reader is closed. +// It is also responsible for closing the datachannel once the stream is closed func (s *stream) spawnControlMessageReader() { s.controlMessageReaderOnce.Do(func() { // Spawn a goroutine to ensure that we're not holding any locks go func() { - defer s.controlMessageReaderDone.Done() // cleanup the sctp deadline timer goroutine defer s.setDataChannelReadDeadline(time.Time{}) - setDeadline := func() bool { - if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) { - s.setDataChannelReadDeadline(s.controlMessageReaderEndTime) - return true - } - return false - } - // Unblock any Read call waiting on reader.ReadMsg s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) @@ -246,24 +228,39 @@ func (s *stream) spawnControlMessageReader() { // We have the lock any readers blocked on reader.ReadMsg have exited. // From this point onwards only this goroutine will do reader.ReadMsg. - //lint:ignore SA2001 we just want to ensure any exising readers have exited. - // Read calls from this point onwards will exit immediately on checking + // released after write half is closed + s.mx.Lock() + + // Read calls after lock release will exit immediately on checking // s.readState s.readerMx.Unlock() - s.mx.Lock() - defer s.mx.Unlock() - if s.nextMessage != nil { s.processIncomingFlag(s.nextMessage.Flag) s.nextMessage = nil } - for s.closeForShutdownErr == nil && - s.sendState != sendStateDataReceived && s.sendState != sendStateReset { - var msg pb.Message - if !setDeadline() { - return + + var endTime time.Time + var msg pb.Message + for { + // connection closed + if s.closeForShutdownErr != nil { + break + } + // write half completed + if s.sendState == sendStateDataReceived || s.sendState == sendStateReset { + break + } + // deadline exceeded + if !endTime.IsZero() && time.Now().After(endTime) { + break + } + + // The stream is closed. Wait for 1RTT before erroring + if s.sendState == sendStateDataSent && endTime.IsZero() { + endTime = time.Now().Add(s.rtt) } + s.setDataChannelReadDeadline(endTime) s.mx.Unlock() err := s.reader.ReadMsg(&msg) s.mx.Lock() @@ -274,21 +271,42 @@ func (s *stream) spawnControlMessageReader() { if errors.Is(err, os.ErrDeadlineExceeded) { continue } - return + break } s.processIncomingFlag(msg.Flag) } + + s.mx.Unlock() + remoteClosed := s.closeDataChannel() + if s.onDataChannelClose != nil { + s.onDataChannelClose(remoteClosed) + } }() }) } -func (s *stream) cleanup() { - // Even if we close the datachannel pion keeps a reference to the datachannel around. - // Remove the onBufferedAmountLow callback to ensure that we at least garbage collect - // memory we allocated for this stream. - s.dataChannel.OnBufferedAmountLow(nil) +// closeDataChannel closes the datachannel and waits for 1rtt for remote to close the datachannel +func (s *stream) closeDataChannel() bool { s.dataChannel.Close() - if s.onDone != nil { - s.onDone() + endTime := time.Now().Add(s.rtt) + var msg pb.Message + for { + if time.Now().After(endTime) { + return false + } + s.setDataChannelReadDeadline(endTime) + err := s.reader.ReadMsg(&msg) + if err == nil || errors.Is(err, os.ErrDeadlineExceeded) { + continue + } + return err == io.EOF } } + +func (s *stream) signalClose() { + s.onCloseOnce.Do(func() { + if s.onClose != nil { + s.onClose() + } + }) +} diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 80d99ea91c..4d7f9b216a 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -103,6 +103,10 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error { func (s *stream) CloseRead() error { s.mx.Lock() defer s.mx.Unlock() + return s.closeReadUnlocked() +} + +func (s *stream) closeReadUnlocked() error { var err error if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil { err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 52b464c0e4..a63434b000 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -102,8 +102,8 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { client, server := getDetachedDataChannels(t) var clientDone, serverDone atomic.Bool - clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) }) - serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) }) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() { clientDone.Store(true) }, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() { serverDone.Store(true) }, nil) // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) @@ -148,8 +148,8 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { func TestStreamPartialReads(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) _, err := serverStr.Write([]byte("foobar")) require.NoError(t, err) @@ -171,8 +171,8 @@ func TestStreamPartialReads(t *testing.T) { func TestStreamSkipEmptyFrames(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) for i := 0; i < 10; i++ { require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{})) @@ -206,7 +206,7 @@ func TestStreamSkipEmptyFrames(t *testing.T) { func TestStreamReadReturnsOnClose(t *testing.T) { client, _ := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) errChan := make(chan error, 1) go func() { _, err := clientStr.Read([]byte{0}) @@ -229,8 +229,8 @@ func TestStreamResets(t *testing.T) { client, server := getDetachedDataChannels(t) var clientDone, serverDone atomic.Bool - clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) }) - serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) }) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() { clientDone.Store(true) }, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() { serverDone.Store(true) }, nil) // send a foobar from the client _, err := clientStr.Write([]byte("foobar")) @@ -265,8 +265,8 @@ func TestStreamResets(t *testing.T) { func TestStreamReadDeadlineAsync(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) timeout := 100 * time.Millisecond if os.Getenv("CI") != "" { @@ -296,8 +296,8 @@ func TestStreamReadDeadlineAsync(t *testing.T) { func TestStreamWriteDeadlineAsync(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) _ = serverStr b := make([]byte, 1024) @@ -326,8 +326,8 @@ func TestStreamWriteDeadlineAsync(t *testing.T) { func TestStreamReadAfterClose(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) serverStr.Close() b := make([]byte, 1) @@ -338,8 +338,8 @@ func TestStreamReadAfterClose(t *testing.T) { client, server = getDetachedDataChannels(t) - clientStr = newStream(client.dc, client.rwc, func() {}) - serverStr = newStream(server.dc, server.rwc, func() {}) + clientStr = newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr = newStream(server.dc, server.rwc, maxRTT, func() {}, nil) serverStr.Reset() b = make([]byte, 1) @@ -351,57 +351,56 @@ func TestStreamReadAfterClose(t *testing.T) { func TestStreamCloseAfterFINACK(t *testing.T) { client, server := getDetachedDataChannels(t) - + rtt := 500 * time.Millisecond done := make(chan bool, 1) - clientStr := newStream(client.dc, client.rwc, func() { done <- true }) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true }) + serverStr := newStream(server.dc, server.rwc, rtt, func() {}, nil) go func() { - done <- true err := clientStr.Close() assert.NoError(t, err) }() - <-done select { case <-done: - t.Fatalf("Close should not have completed without processing FIN_ACK") case <-time.After(200 * time.Millisecond): + t.Fatalf("Close should call onClose immediately") } b := make([]byte, 1) _, err := serverStr.Read(b) require.Error(t, err) require.ErrorIs(t, err, io.EOF) + serverStr.Close() select { case <-done: - case <-time.After(3 * time.Second): - t.Errorf("Close should have completed") + case <-time.After(rtt): + t.Errorf("data channel close should have completed") } } -// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half +// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent after the write half // of the stream is closed. func TestStreamFinAckAfterStopSending(t *testing.T) { client, server := getDetachedDataChannels(t) + rtt := 500 * time.Millisecond done := make(chan bool, 1) - clientStr := newStream(client.dc, client.rwc, func() { done <- true }) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true }) + serverStr := newStream(server.dc, server.rwc, rtt, func() {}, nil) go func() { clientStr.CloseRead() clientStr.Write([]byte("hello world")) - done <- true err := clientStr.Close() assert.NoError(t, err) }() - <-done + // As the serverStr is not reading clientStr cannot get FIN_ACK select { case <-done: - t.Errorf("Close should not have completed without processing FIN_ACK") case <-time.After(500 * time.Millisecond): + t.Errorf("onClose should have been called immediately") } // serverStr has write half closed and read half open @@ -409,10 +408,10 @@ func TestStreamFinAckAfterStopSending(t *testing.T) { b := make([]byte, 24) _, err := serverStr.Read(b) require.NoError(t, err) - serverStr.Close() // Sends stop_sending, fin + serverStr.Close() // Sends stop_sending, fin and closes datachannel select { case <-done: - case <-time.After(5 * time.Second): + case <-time.After(rtt): t.Fatalf("Close should have completed") } } @@ -422,8 +421,8 @@ func TestStreamConcurrentClose(t *testing.T) { start := make(chan bool, 2) done := make(chan bool, 2) - clientStr := newStream(client.dc, client.rwc, func() { done <- true }) - serverStr := newStream(server.dc, server.rwc, func() { done <- true }) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() { done <- true }, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() { done <- true }, nil) go func() { start <- true @@ -448,39 +447,41 @@ func TestStreamConcurrentClose(t *testing.T) { } } +// TestStreamResetAfterClose tests that controlMessageReader skips processing FIN_ACK on Reset func TestStreamResetAfterClose(t *testing.T) { client, _ := getDetachedDataChannels(t) + rtt := 500 * time.Millisecond done := make(chan bool, 2) - clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true }) clientStr.Close() select { case <-done: - t.Fatalf("Close shouldn't run cleanup immediately") - case <-time.After(500 * time.Millisecond): + case <-time.After(200 * time.Millisecond): + t.Fatalf("Close should run onClose immediately") } clientStr.Reset() select { case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Reset should run callback immediately") + case <-time.After(rtt + 100*time.Millisecond): + t.Fatalf("close should run ") } } func TestStreamDataChannelCloseOnFINACK(t *testing.T) { client, server := getDetachedDataChannels(t) + rtt := 200 * time.Millisecond done := make(chan bool, 1) - clientStr := newStream(client.dc, client.rwc, func() { done <- true }) - + clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true }) clientStr.Close() select { case <-done: - t.Fatalf("Close shouldn't run cleanup immediately") case <-time.After(500 * time.Millisecond): + t.Fatalf("Close should run onClose immediately") } serverWriter := pbio.NewDelimitedWriter(server.rwc) @@ -488,8 +489,8 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) { require.NoError(t, err) select { case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("Callback should be run on reading FIN_ACK") + case <-time.After(rtt + (50 * time.Millisecond)): + t.Fatalf("data channel close should complete within rtt") } b := make([]byte, 100) N := 0 @@ -507,8 +508,8 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) { func TestStreamChunking(t *testing.T) { client, server := getDetachedDataChannels(t) - clientStr := newStream(client.dc, client.rwc, func() {}) - serverStr := newStream(server.dc, server.rwc, func() {}) + clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil) + serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil) const N = (16 << 10) + 1000 go func() { diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 82d4ac287d..7d5e26e921 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -119,16 +119,15 @@ func (s *stream) availableSendSpace() int { return availableSpace } -func (s *stream) cancelWrite() error { - s.mx.Lock() - defer s.mx.Unlock() - +func (s *stream) cancelWriteUnlocked() error { // There's no need to reset the write half if the write half has been closed // successfully or has been reset previously if s.sendState == sendStateDataReceived || s.sendState == sendStateReset { return nil } s.sendState = sendStateReset + // Remove reference to this stream from the datachannel + s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { return err @@ -139,16 +138,18 @@ func (s *stream) cancelWrite() error { func (s *stream) CloseWrite() error { s.mx.Lock() defer s.mx.Unlock() + return s.closeWriteUnlocked() +} +func (s *stream) closeWriteUnlocked() error { if s.sendState != sendStateSending { return nil } s.sendState = sendStateDataSent + // Remove reference to this stream from the datachannel + s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { - return err - } - return nil + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}) } func (s *stream) notifyWriteStateChanged() { diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index ae7cc3a5d7..e097bbf4f4 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -366,7 +366,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement if err != nil { return nil, err } - channel := newStream(w.HandshakeDataChannel, detached, func() {}) + channel := newStream(w.HandshakeDataChannel, detached, maxRTT, nil, nil) remotePubKey, err := t.noiseHandshake(ctx, w.PeerConnection, channel, p, remoteHashFunction, false) if err != nil {