From 3832d7945871d4b17f10ead4bc8a0c0e28e86d7e Mon Sep 17 00:00:00 2001
From: Sukun
Date: Tue, 5 Mar 2024 21:54:58 +0530
Subject: [PATCH 1/2] webrtc: close data channels cleanly

WebRTC data channel close is a synchronous procedure: we close our
outgoing stream, and in response the peer is expected to close its
outgoing stream. If the peer doesn't close its side of the stream, we
end up with a memory leak where the SCTP transport keeps a reference
to the stream. So we count invalid data channel closures, and when the
count goes over a threshold we close the connection.

For our purposes we could fork SCTP and implement a unilateral stream
reset, which is feasible because we already have a state machine on
top of the data channels. But for an RFC-compliant SCTP
implementation, this synchronous close is how the spec is supposed to
work. SCTP stream numbers are limited (uint16), so we do need to reuse
stream IDs, which forces us to use a synchronous close mechanism.
---
 p2p/transport/webrtc/connection.go   |  26 +++-
 p2p/transport/webrtc/listener.go     |   2 +-
 p2p/transport/webrtc/stream.go       | 170 +++++++++++++++------------
 p2p/transport/webrtc/stream_read.go  |   4 +
 p2p/transport/webrtc/stream_test.go  |  95 +++++++--------
 p2p/transport/webrtc/stream_write.go |  17 +--
 p2p/transport/webrtc/transport.go    |   2 +-
 7 files changed, 177 insertions(+), 139 deletions(-)

diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go
index cd28ec3203..690799884b 100644
--- a/p2p/transport/webrtc/connection.go
+++ b/p2p/transport/webrtc/connection.go
@@ -21,7 +21,10 @@ import (
 
 var _ tpt.CapableConn = &connection{}
 
-const maxAcceptQueueLen = 256
+const (
+	maxAcceptQueueLen             = 256
+	maxInvalidDataChannelClosures = 10
+)
 
 type errConnectionTimeout struct{}
 
@@ -51,9 +54,10 @@ type connection struct {
 	remoteKey       ic.PubKey
 	remoteMultiaddr ma.Multiaddr
 
-	m            sync.Mutex
-	streams      map[uint16]*stream
-	nextStreamID atomic.Int32
+	m                          sync.Mutex
+	streams                    map[uint16]*stream
+	nextStreamID               atomic.Int32
+	invalidDataChannelClosures atomic.Int32
 
 	acceptQueue chan dataChannel
 
@@ -158,7 +162,7 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
 		dc.Close()
 		return nil, fmt.Errorf("detach channel failed for stream(%d): %w", streamID, err)
 	}
-	str := newStream(dc, rwc, func() { c.removeStream(streamID) })
+	str := newStream(dc, rwc, maxRTT, func() { c.removeStream(streamID) }, c.onDataChannelClose)
 	if err := c.addStream(str); err != nil {
 		str.Reset()
 		return nil, fmt.Errorf("failed to add stream(%d) to connection: %w", streamID, err)
@@ -171,7 +175,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) {
 	case <-c.ctx.Done():
 		return nil, c.closeErr
 	case dc := <-c.acceptQueue:
-		str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) })
+		str := newStream(dc.channel, dc.stream, maxRTT,func() { c.removeStream(*dc.channel.ID()) }, c.onDataChannelClose)
 		if err := c.addStream(str); err != nil {
 			str.Reset()
 			return nil, err
@@ -207,6 +211,16 @@ func (c *connection) removeStream(id uint16) {
 	delete(c.streams, id)
 }
 
+func (c *connection) onDataChannelClose(remoteClosed bool) {
+	if !remoteClosed {
+		if c.invalidDataChannelClosures.Add(1) > maxInvalidDataChannelClosures {
+			c.closeOnce.Do(func() {
+				c.closeWithError(errors.New("peer is not closing datachannels"))
+			})
+		}
+	}
+}
+
 func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
 	if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
 		c.closeOnce.Do(func() {

diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go
index af2991ed83..ff6265e60d 100644
--- a/p2p/transport/webrtc/listener.go
+++ b/p2p/transport/webrtc/listener.go
@@ -257,7 +257,7 @@ func (l *listener) setupConnection(
 	if err != nil {
 		return nil, err
 	}
-	handshakeChannel := newStream(w.HandshakeDataChannel, rwc, func() {})
+	handshakeChannel := newStream(w.HandshakeDataChannel, rwc, maxRTT, nil, nil)
 	// we do not yet know A's peer ID so accept any inbound
 	remotePubKey, err := l.transport.noiseHandshake(ctx, w.PeerConnection, handshakeChannel, "", crypto.SHA256, true)
 	if err != nil {

diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go
index 45641321dc..7ee3e71873 100644
--- a/p2p/transport/webrtc/stream.go
+++ b/p2p/transport/webrtc/stream.go
@@ -2,6 +2,7 @@ package libp2pwebrtc
 
 import (
 	"errors"
+	"io"
 	"os"
 	"sync"
 	"time"
@@ -35,7 +36,6 @@ const (
 	// add messages to the send buffer once there is space for 1 full
 	// sized message.
 	bufferedAmountLowThreshold = maxBufferedAmount / 2
-
 	// Proto overhead assumption is 5 bytes
 	protoOverhead = 5
 	// Varint overhead is assumed to be 2 bytes. This is safe since
@@ -45,9 +45,9 @@ const (
 	// is less than or equal to 2 ^ 14, the varint will not be more than
 	// 2 bytes in length.
 	varintOverhead = 2
-	// maxFINACKWait is the maximum amount of time a stream will wait to read
-	// FIN_ACK before closing the data channel
-	maxFINACKWait = 10 * time.Second
+	// maxRTT is an estimate of maximum RTT
+	// We use this to wait for FIN_ACK and Data Channel Close messages from the peer
+	maxRTT = 10 * time.Second
 )
 
 type receiveState uint8
@@ -89,18 +89,15 @@ type stream struct {
 	writeDeadline time.Time
 
 	controlMessageReaderOnce sync.Once
-	// controlMessageReaderEndTime is the end time for reading FIN_ACK from the control
-	// message reader. We cannot rely on SetReadDeadline to do this since that is prone to
-	// race condition where a previous deadline timer fires after the latest call to
-	// SetReadDeadline
-	// See: https://github.com/pion/sctp/pull/290
-	controlMessageReaderEndTime time.Time
-	controlMessageReaderDone    sync.WaitGroup
-
-	onDone func()
+
+	onCloseOnce        sync.Once
+	onClose            func()
+	onDataChannelClose func(remoteClosed bool)
 	id                  uint16 // for logging purposes
 	dataChannel         *datachannel.DataChannel
 	closeForShutdownErr error
+	isClosed            bool
+	rtt                 time.Duration
 }
 
 var _ network.MuxedStream = &stream{}
 
@@ -108,18 +105,20 @@ var _ network.MuxedStream = &stream{}
 func newStream(
 	channel *webrtc.DataChannel,
 	rwc datachannel.ReadWriteCloser,
-	onDone func(),
+	rtt time.Duration,
+	onClose func(),
+	onDataChannelClose func(remoteClosed bool),
 ) *stream {
 	s := &stream{
-		reader:            pbio.NewDelimitedReader(rwc, maxMessageSize),
-		writer:            pbio.NewDelimitedWriter(rwc),
-		writeStateChanged: make(chan struct{}, 1),
-		id:                *channel.ID(),
-		dataChannel:       rwc.(*datachannel.DataChannel),
-		onDone:            onDone,
+		reader:             pbio.NewDelimitedReader(rwc, maxMessageSize),
+		writer:             pbio.NewDelimitedWriter(rwc),
+		writeStateChanged:  make(chan struct{}, 1),
+		id:                 *channel.ID(),
+		dataChannel:        rwc.(*datachannel.DataChannel),
+		onClose:            onClose,
+		onDataChannelClose: onDataChannelClose,
+		rtt:                rtt,
 	}
-	// released when the controlMessageReader goroutine exits
-	s.controlMessageReaderDone.Add(1)
 	s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
 	s.dataChannel.OnBufferedAmountLow(func() {
 		s.notifyWriteStateChanged()
@@ -129,55 +128,46 @@ func newStream(
 }
 
 func (s *stream) Close() error {
+	defer s.signalClose()
 	s.mx.Lock()
-	isClosed := s.closeForShutdownErr != nil
-	s.mx.Unlock()
-	if isClosed {
+	if s.closeForShutdownErr != nil || s.isClosed {
+		s.mx.Unlock()
 		return nil
 	}
+	s.isClosed = true
+	closeWriteErr := s.closeWriteUnlocked()
+	closeReadErr := s.closeReadUnlocked()
+	s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
+	s.mx.Unlock()
 
-	closeWriteErr := s.CloseWrite()
-	closeReadErr := s.CloseRead()
 	if closeWriteErr != nil || closeReadErr != nil {
 		s.Reset()
 		return errors.Join(closeWriteErr, closeReadErr)
 	}
-
-	s.mx.Lock()
-	if s.controlMessageReaderEndTime.IsZero() {
-		s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait)
-		s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
-		go func() {
-			s.controlMessageReaderDone.Wait()
-			s.cleanup()
-		}()
-	}
-	s.mx.Unlock()
 	return nil
 }
 
 func (s *stream) Reset() error {
+	defer s.signalClose()
 	s.mx.Lock()
-	isClosed := s.closeForShutdownErr != nil
-	s.mx.Unlock()
-	if isClosed {
+	defer s.mx.Unlock()
+	if s.closeForShutdownErr != nil {
 		return nil
 	}
-
-	defer s.cleanup()
-	cancelWriteErr := s.cancelWrite()
-	closeReadErr := s.CloseRead()
+	// reset even if it's closed already
+	s.isClosed = true
+	cancelWriteErr := s.cancelWriteUnlocked()
+	closeReadErr := s.closeReadUnlocked()
 	s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
-	return errors.Join(closeReadErr, cancelWriteErr)
+	return errors.Join(cancelWriteErr, closeReadErr)
 }
 
 func (s *stream) closeForShutdown(closeErr error) {
-	defer s.cleanup()
-
+	defer s.signalClose()
 	s.mx.Lock()
 	defer s.mx.Unlock()
-
 	s.closeForShutdownErr = closeErr
+	s.isClosed = true
 	s.notifyWriteStateChanged()
 }
@@ -223,22 +213,14 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
 }
 
 // spawnControlMessageReader is used for processing control messages after the reader is closed.
+// It is also responsible for closing the datachannel once the stream is closed
 func (s *stream) spawnControlMessageReader() {
 	s.controlMessageReaderOnce.Do(func() {
 		// Spawn a goroutine to ensure that we're not holding any locks
 		go func() {
-			defer s.controlMessageReaderDone.Done()
 			// cleanup the sctp deadline timer goroutine
 			defer s.setDataChannelReadDeadline(time.Time{})
 
-			setDeadline := func() bool {
-				if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) {
-					s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
-					return true
-				}
-				return false
-			}
-
 			// Unblock any Read call waiting on reader.ReadMsg
 			s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
 
@@ -246,24 +228,39 @@ func (s *stream) spawnControlMessageReader() {
 			// We have the lock any readers blocked on reader.ReadMsg have exited.
 			// From this point onwards only this goroutine will do reader.ReadMsg.
-			//lint:ignore SA2001 we just want to ensure any exising readers have exited.
-			// Read calls from this point onwards will exit immediately on checking
+
+			// released after write half is closed
+			s.mx.Lock()
+
+			// Read calls after lock release will exit immediately on checking
 			// s.readState
 			s.readerMx.Unlock()
 
-			s.mx.Lock()
-			defer s.mx.Unlock()
-
 			if s.nextMessage != nil {
 				s.processIncomingFlag(s.nextMessage.Flag)
 				s.nextMessage = nil
 			}
-			for s.closeForShutdownErr == nil &&
-				s.sendState != sendStateDataReceived && s.sendState != sendStateReset {
-				var msg pb.Message
-				if !setDeadline() {
-					return
+
+			var endTime time.Time
+			var msg pb.Message
+			for {
+				// connection closed
+				if s.closeForShutdownErr != nil {
+					break
+				}
+				// write half completed
+				if s.sendState == sendStateDataReceived || s.sendState == sendStateReset {
+					break
+				}
+				// deadline exceeded
+				if !endTime.IsZero() && time.Now().After(endTime) {
+					break
+				}
+
+				// The stream is closed. Wait for 1RTT before erroring
+				if s.isClosed && endTime.IsZero() {
+					endTime = time.Now().Add(s.rtt)
 				}
+				s.setDataChannelReadDeadline(endTime)
 				s.mx.Unlock()
 				err := s.reader.ReadMsg(&msg)
 				s.mx.Lock()
@@ -274,21 +271,42 @@ func (s *stream) spawnControlMessageReader() {
 				if errors.Is(err, os.ErrDeadlineExceeded) {
 					continue
 				}
-				return
+				break
 				}
 				s.processIncomingFlag(msg.Flag)
 			}
+
+			s.mx.Unlock()
+			remoteClosed := s.closeDataChannel()
+			if s.onDataChannelClose != nil {
+				s.onDataChannelClose(remoteClosed)
+			}
 		}()
 	})
 }
 
-func (s *stream) cleanup() {
-	// Even if we close the datachannel pion keeps a reference to the datachannel around.
-	// Remove the onBufferedAmountLow callback to ensure that we at least garbage collect
-	// memory we allocated for this stream.
-	s.dataChannel.OnBufferedAmountLow(nil)
+// closeDataChannel closes the datachannel and waits 1 RTT for the remote to close the datachannel
+func (s *stream) closeDataChannel() bool {
 	s.dataChannel.Close()
-	if s.onDone != nil {
-		s.onDone()
+	endTime := time.Now().Add(s.rtt)
+	var msg pb.Message
+	for {
+		if time.Now().After(endTime) {
+			return false
+		}
+		s.setDataChannelReadDeadline(endTime)
+		err := s.reader.ReadMsg(&msg)
+		if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
+			continue
+		}
+		return err == io.EOF
 	}
 }
+
+func (s *stream) signalClose() {
+	s.onCloseOnce.Do(func() {
+		if s.onClose != nil {
+			s.onClose()
+		}
+	})
+}

diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go
index 80d99ea91c..4d7f9b216a 100644
--- a/p2p/transport/webrtc/stream_read.go
+++ b/p2p/transport/webrtc/stream_read.go
@@ -103,6 +103,10 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error {
 func (s *stream) CloseRead() error {
 	s.mx.Lock()
 	defer s.mx.Unlock()
+	return s.closeReadUnlocked()
+}
+
+func (s *stream) closeReadUnlocked() error {
 	var err error
 	if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil {
 		err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})

diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go
index 52b464c0e4..a63434b000 100644
--- a/p2p/transport/webrtc/stream_test.go
+++ b/p2p/transport/webrtc/stream_test.go
@@ -102,8 +102,8 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
 	var clientDone, serverDone atomic.Bool
-	clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) })
-	serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) })
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() { clientDone.Store(true) }, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() { serverDone.Store(true) }, nil)
 
 	// send a foobar from the client
 	n, err := clientStr.Write([]byte("foobar"))
@@ -148,8 +148,8 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
 func TestStreamPartialReads(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	_, err := serverStr.Write([]byte("foobar"))
 	require.NoError(t, err)
@@ -171,8 +171,8 @@ func TestStreamPartialReads(t *testing.T) {
 func TestStreamSkipEmptyFrames(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	for i := 0; i < 10; i++ {
 		require.NoError(t, serverStr.writer.WriteMsg(&pb.Message{}))
@@ -206,7 +206,7 @@ func TestStreamSkipEmptyFrames(t *testing.T) {
 func TestStreamReadReturnsOnClose(t *testing.T) {
 	client, _ := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
 	errChan := make(chan error, 1)
 	go func() {
 		_, err := clientStr.Read([]byte{0})
@@ -229,8 +229,8 @@ func TestStreamResets(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
 	var clientDone, serverDone atomic.Bool
-	clientStr := newStream(client.dc, client.rwc, func() { clientDone.Store(true) })
-	serverStr := newStream(server.dc, server.rwc, func() { serverDone.Store(true) })
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() { clientDone.Store(true) }, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() { serverDone.Store(true) }, nil)
 
 	// send a foobar from the client
 	_, err := clientStr.Write([]byte("foobar"))
@@ -265,8 +265,8 @@ func TestStreamResets(t *testing.T) {
 func TestStreamReadDeadlineAsync(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	timeout := 100 * time.Millisecond
 	if os.Getenv("CI") != "" {
@@ -296,8 +296,8 @@ func TestStreamReadDeadlineAsync(t *testing.T) {
 func TestStreamWriteDeadlineAsync(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 	_ = serverStr
 
 	b := make([]byte, 1024)
@@ -326,8 +326,8 @@ func TestStreamWriteDeadlineAsync(t *testing.T) {
 func TestStreamReadAfterClose(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	serverStr.Close()
 	b := make([]byte, 1)
@@ -338,8 +338,8 @@ func TestStreamReadAfterClose(t *testing.T) {
 
 	client, server = getDetachedDataChannels(t)
 
-	clientStr = newStream(client.dc, client.rwc, func() {})
-	serverStr = newStream(server.dc, server.rwc, func() {})
+	clientStr = newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr = newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	serverStr.Reset()
 	b = make([]byte, 1)
@@ -351,57 +351,56 @@ func TestStreamReadAfterClose(t *testing.T) {
 func TestStreamCloseAfterFINACK(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
-
+	rtt := 500 * time.Millisecond
 	done := make(chan bool, 1)
-	clientStr := newStream(client.dc, client.rwc, func() { done <- true })
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true })
+	serverStr := newStream(server.dc, server.rwc, rtt, func() {}, nil)
 
 	go func() {
-		done <- true
 		err := clientStr.Close()
 		assert.NoError(t, err)
 	}()
-	<-done
 
 	select {
 	case <-done:
-		t.Fatalf("Close should not have completed without processing FIN_ACK")
 	case <-time.After(200 * time.Millisecond):
+		t.Fatalf("Close should call onClose immediately")
 	}
 
 	b := make([]byte, 1)
 	_, err := serverStr.Read(b)
 	require.Error(t, err)
 	require.ErrorIs(t, err, io.EOF)
 
+	serverStr.Close()
 	select {
 	case <-done:
-	case <-time.After(3 * time.Second):
-		t.Errorf("Close should have completed")
+	case <-time.After(rtt):
+		t.Errorf("data channel close should have completed")
 	}
 }
 
-// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half
+// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent after the write half
 // of the stream is closed.
 func TestStreamFinAckAfterStopSending(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
+	rtt := 500 * time.Millisecond
 	done := make(chan bool, 1)
-	clientStr := newStream(client.dc, client.rwc, func() { done <- true })
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true })
+	serverStr := newStream(server.dc, server.rwc, rtt, func() {}, nil)
 
 	go func() {
 		clientStr.CloseRead()
 		clientStr.Write([]byte("hello world"))
-		done <- true
 		err := clientStr.Close()
 		assert.NoError(t, err)
 	}()
-	<-done
 
+	// As the serverStr is not reading, clientStr cannot get FIN_ACK
 	select {
 	case <-done:
-		t.Errorf("Close should not have completed without processing FIN_ACK")
 	case <-time.After(500 * time.Millisecond):
+		t.Errorf("onClose should have been called immediately")
 	}
 
 	// serverStr has write half closed and read half open
@@ -409,10 +408,10 @@ func TestStreamFinAckAfterStopSending(t *testing.T) {
 	b := make([]byte, 24)
 	_, err := serverStr.Read(b)
 	require.NoError(t, err)
-	serverStr.Close() // Sends stop_sending, fin
+	serverStr.Close() // Sends stop_sending, fin and closes datachannel
 	select {
 	case <-done:
-	case <-time.After(5 * time.Second):
+	case <-time.After(rtt):
 		t.Fatalf("Close should have completed")
 	}
 }
@@ -422,8 +421,8 @@ func TestStreamConcurrentClose(t *testing.T) {
 	start := make(chan bool, 2)
 	done := make(chan bool, 2)
-	clientStr := newStream(client.dc, client.rwc, func() { done <- true })
-	serverStr := newStream(server.dc, server.rwc, func() { done <- true })
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() { done <- true }, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() { done <- true }, nil)
 
 	go func() {
 		start <- true
@@ -448,39 +447,41 @@ func TestStreamConcurrentClose(t *testing.T) {
 	}
 }
 
+// TestStreamResetAfterClose tests that controlMessageReader skips processing FIN_ACK on Reset
 func TestStreamResetAfterClose(t *testing.T) {
 	client, _ := getDetachedDataChannels(t)
 
+	rtt := 500 * time.Millisecond
 	done := make(chan bool, 2)
-	clientStr := newStream(client.dc, client.rwc, func() { done <- true })
+	clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true })
 	clientStr.Close()
 
 	select {
 	case <-done:
-		t.Fatalf("Close shouldn't run cleanup immediately")
-	case <-time.After(500 * time.Millisecond):
+	case <-time.After(200 * time.Millisecond):
+		t.Fatalf("Close should run onClose immediately")
 	}
 
 	clientStr.Reset()
 	select {
 	case <-done:
-	case <-time.After(2 * time.Second):
-		t.Fatalf("Reset should run callback immediately")
+	case <-time.After(rtt + 100*time.Millisecond):
+		t.Fatalf("Reset should have completed within 1 RTT")
 	}
 }
 
 func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
+	rtt := 200 * time.Millisecond
 	done := make(chan bool, 1)
-	clientStr := newStream(client.dc, client.rwc, func() { done <- true })
-
+	clientStr := newStream(client.dc, client.rwc, rtt, func() { done <- true }, func(_ bool) { done <- true })
 	clientStr.Close()
 
 	select {
 	case <-done:
-		t.Fatalf("Close shouldn't run cleanup immediately")
 	case <-time.After(500 * time.Millisecond):
+		t.Fatalf("Close should run onClose immediately")
 	}
 
 	serverWriter := pbio.NewDelimitedWriter(server.rwc)
@@ -488,8 +489,8 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
 	require.NoError(t, err)
 	select {
 	case <-done:
-	case <-time.After(2 * time.Second):
-		t.Fatalf("Callback should be run on reading FIN_ACK")
+	case <-time.After(rtt + (50 * time.Millisecond)):
+		t.Fatalf("data channel close should complete within rtt")
 	}
 	b := make([]byte, 100)
 	N := 0
@@ -507,8 +508,8 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
 func TestStreamChunking(t *testing.T) {
 	client, server := getDetachedDataChannels(t)
 
-	clientStr := newStream(client.dc, client.rwc, func() {})
-	serverStr := newStream(server.dc, server.rwc, func() {})
+	clientStr := newStream(client.dc, client.rwc, maxRTT, func() {}, nil)
+	serverStr := newStream(server.dc, server.rwc, maxRTT, func() {}, nil)
 
 	const N = (16 << 10) + 1000
 	go func() {

diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go
index 82d4ac287d..7d5e26e921 100644
--- a/p2p/transport/webrtc/stream_write.go
+++ b/p2p/transport/webrtc/stream_write.go
@@ -119,16 +119,15 @@ func (s *stream) availableSendSpace() int {
 	return availableSpace
 }
 
-func (s *stream) cancelWrite() error {
-	s.mx.Lock()
-	defer s.mx.Unlock()
-
+func (s *stream) cancelWriteUnlocked() error {
 	// There's no need to reset the write half if the write half has been closed
 	// successfully or has been reset previously
 	if s.sendState == sendStateDataReceived || s.sendState == sendStateReset {
 		return nil
 	}
 	s.sendState = sendStateReset
+	// Remove reference to this stream from the datachannel
+	s.dataChannel.OnBufferedAmountLow(nil)
 	s.notifyWriteStateChanged()
 	if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil {
 		return err
@@ -139,16 +138,18 @@ func (s *stream) CloseWrite() error {
 	s.mx.Lock()
 	defer s.mx.Unlock()
+	return s.closeWriteUnlocked()
+}
 
+func (s *stream) closeWriteUnlocked() error {
 	if s.sendState != sendStateSending {
 		return nil
 	}
 	s.sendState = sendStateDataSent
+	// Remove reference to this stream from the datachannel
+	s.dataChannel.OnBufferedAmountLow(nil)
 	s.notifyWriteStateChanged()
-	if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
-		return err
-	}
-	return nil
+	return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()})
 }
 
 func (s *stream) notifyWriteStateChanged() {

diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go
index ae7cc3a5d7..e097bbf4f4 100644
--- a/p2p/transport/webrtc/transport.go
+++ b/p2p/transport/webrtc/transport.go
@@ -366,7 +366,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
 	if err != nil {
 		return nil, err
 	}
-	channel := newStream(w.HandshakeDataChannel, detached, func() {})
+	channel := newStream(w.HandshakeDataChannel, detached, maxRTT, nil, nil)
 	remotePubKey, err := t.noiseHandshake(ctx, w.PeerConnection, channel, p, remoteHashFunction, false)
 	if err != nil {

From 4fb66c5c550301a65069797e73ce40e339e65be4 Mon Sep 17 00:00:00 2001
From: Sukun
Date: Wed, 6 Mar 2024 16:49:49 +0530
Subject: [PATCH 2/2] fix transport integration test

---
 p2p/transport/webrtc/connection.go | 3 ++-
 p2p/transport/webrtc/stream.go     | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go
index 690799884b..8cc24d89a6 100644
--- a/p2p/transport/webrtc/connection.go
+++ b/p2p/transport/webrtc/connection.go
@@ -175,7 +175,7 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) {
 	case <-c.ctx.Done():
 		return nil, c.closeErr
 	case dc := <-c.acceptQueue:
-		str := newStream(dc.channel, dc.stream, maxRTT,func() { c.removeStream(*dc.channel.ID()) }, c.onDataChannelClose)
+		str := newStream(dc.channel, dc.stream, maxRTT, func() { c.removeStream(*dc.channel.ID()) }, c.onDataChannelClose)
 		if err := c.addStream(str); err != nil {
 			str.Reset()
 			return nil, err
@@ -215,6 +215,7 @@ func (c *connection) onDataChannelClose(remoteClosed bool) {
 	if !remoteClosed {
 		if c.invalidDataChannelClosures.Add(1) > maxInvalidDataChannelClosures {
 			c.closeOnce.Do(func() {
+				log.Error("closing connection as peer is not closing datachannels: ", c.RemotePeer(), c.RemoteMultiaddr())
 				c.closeWithError(errors.New("peer is not closing datachannels"))
 			})
 		}

diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go
index 7ee3e71873..8d21469ce3 100644
--- a/p2p/transport/webrtc/stream.go
+++ b/p2p/transport/webrtc/stream.go
@@ -257,7 +257,7 @@ func (s *stream) spawnControlMessageReader() {
 			}
 
 			// The stream is closed. Wait for 1RTT before erroring
-			if s.isClosed && endTime.IsZero() {
+			if s.sendState == sendStateDataSent && endTime.IsZero() {
 				endTime = time.Now().Add(s.rtt)
 			}
 			s.setDataChannelReadDeadline(endTime)
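
For reviewers, a minimal, self-contained sketch of the close sequence this
patch series implements. It models the two RTT-bounded waits with plain Go
channels instead of pion data channels; the names closeStream, finACK, and
peerClosed are illustrative only and do not appear in the patch.

package main

import (
	"fmt"
	"time"
)

// rtt stands in for the maxRTT estimate used by the patch.
const rtt = 50 * time.Millisecond

// closeStream models the local half of the synchronous close: after FIN is
// sent (not modeled here), wait up to 1 RTT for FIN_ACK, close the data
// channel, then wait up to 1 RTT for the peer to close its side. It reports
// whether the peer closed cleanly.
func closeStream(finACK, peerClosed <-chan struct{}) (remoteClosed bool) {
	select {
	case <-finACK:
	case <-time.After(rtt):
		// The control-message reader gives up on FIN_ACK after 1 RTT.
	}
	// Close our side (s.closeDataChannel in the patch) and wait for the
	// peer's close. If it never arrives, the closure is invalid and counts
	// toward maxInvalidDataChannelClosures on the connection.
	select {
	case <-peerClosed:
		return true
	case <-time.After(rtt):
		return false
	}
}

func main() {
	finACK := make(chan struct{})
	peerClosed := make(chan struct{})
	// A well-behaved peer acks the FIN and closes its outgoing side.
	go func() {
		close(finACK)
		close(peerClosed)
	}()
	if closeStream(finACK, peerClosed) {
		fmt.Println("clean close: the stream id can be reused")
	} else {
		fmt.Println("invalid closure: counted toward the connection threshold")
	}
}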