Skip to content

Commit

Permalink
netstack: move TCP state to tcp package and cleanup probe
Browse files Browse the repository at this point in the history
TCP probe support is implemented, unnecessarily, across the stack and and tcp
packages. It can live entirely in tcp. Additionally, it is only ever set at
initialization time, so support for dynamically adding/removing the probe isn't
necesary.

The probe is getting in the way of adding debugging for b/339664055.

PiperOrigin-RevId: 695793246
  • Loading branch information
kevinGC authored and gvisor-bot committed Nov 22, 2024
1 parent 9e0e42b commit 86cd4d3
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 218 deletions.
1 change: 0 additions & 1 deletion pkg/tcpip/stack/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ go_library(
"stack_mutex.go",
"stack_options.go",
"state_conn_mutex.go",
"tcp.go",
"transport_demuxer.go",
"transport_endpoints_mutex.go",
"tuple_list.go",
Expand Down
55 changes: 14 additions & 41 deletions pkg/tcpip/stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
package stack

import (
"context"
"encoding/binary"
"fmt"
"io"
"math/rand"
"sync/atomic"
"time"

"golang.org/x/time/rate"
Expand Down Expand Up @@ -108,11 +108,6 @@ type Stack struct {

*ports.PortManager

// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
// TODO(b/341946753): Restore them when netstack is savable.
tcpProbeFunc atomic.Value `state:"nosave"` // TCPProbeFunc

// clock is used to generate user-visible times.
clock tcpip.Clock

Expand Down Expand Up @@ -2139,41 +2134,6 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra
return nil
}

// AddTCPProbe installs a probe function that will be invoked on every segment
// received by a given TCP endpoint. The probe function is passed a copy of the
// TCP endpoint state before and after processing of the segment.
//
// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
// created prior to this call will not call the probe function.
//
// Further, installing two different probes back to back can result in some
// endpoints calling the first one and some the second one. There is no
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
s.tcpProbeFunc.Store(probe)
}

// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
p := s.tcpProbeFunc.Load()
if p == nil {
return nil
}
return p.(TCPProbeFunc)
}

// RemoveTCPProbe removes an installed TCP probe.
//
// NOTE: This only ensures that endpoints created after this call do not
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
// This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics.
s.tcpProbeFunc.Store(TCPProbeFunc(nil))
}

// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
Expand Down Expand Up @@ -2452,3 +2412,16 @@ func (s *Stack) IsSaveRestoreEnabled() bool {

return s.saveRestoreEnabled
}

// contextID is this package's type for context.Context.Value keys.
type contextID int

const (
// CtxRestoreStack is a Context.Value key for the stack to be used in restore.
CtxRestoreStack contextID = iota
)

// RestoreStackFromContext returns the stack to be used during restore.
func RestoreStackFromContext(ctx context.Context) *Stack {
return ctx.Value(CtxRestoreStack).(*Stack)
}
1 change: 1 addition & 0 deletions pkg/tcpip/transport/tcp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ go_library(
"segment_state.go",
"segment_unsafe.go",
"snd.go",
"state.go",
"tcp_endpoint_list.go",
"tcp_segment_list.go",
"tcp_segment_refs.go",
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/tcp/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ func (e *Endpoint) handleSegmentsLocked() tcpip.Error {
// +checklocks:e.mu
func (e *Endpoint) probeSegmentLocked() {
if fn := e.probe; fn != nil {
var state stack.TCPEndpointState
var state TCPEndpointState
e.completeStateLocked(&state)
fn(&state)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/tcpip/transport/tcp/cubic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"time"

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// effectivelyInfinity is an initialization value used for round-trip times
Expand Down Expand Up @@ -58,7 +57,7 @@ const (
// See: https://tools.ietf.org/html/rfc8312.
// +stateify savable
type cubicState struct {
stack.TCPCubicState
TCPCubicState

// numCongestionEvents tracks the number of congestion events since last
// RTO.
Expand All @@ -72,7 +71,7 @@ type cubicState struct {
func newCubicCC(s *sender) *cubicState {
now := s.ep.stack.Clock().NowMonotonic()
return &cubicState{
TCPCubicState: stack.TCPCubicState{
TCPCubicState: TCPCubicState{
T: now,
Beta: 0.7,
C: 0.4,
Expand Down
8 changes: 4 additions & 4 deletions pkg/tcpip/transport/tcp/cubic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestHyStartAckTrainOK(t *testing.T) {
iss := seqnum.Value(0)
snd := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndUna: iss + 1,
SndNxt: iss + 1,
Ssthresh: InitialSsthresh,
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestHyStartAckTrainTooSpread(t *testing.T) {
iss := seqnum.Value(0)
snd := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndUna: iss + 1,
SndNxt: iss + 1,
Ssthresh: InitialSsthresh,
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestHyStartDelayOK(t *testing.T) {
iss := seqnum.Value(0)
snd := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndUna: iss + 1,
SndNxt: iss + 1,
Ssthresh: InitialSsthresh,
Expand Down Expand Up @@ -241,7 +241,7 @@ func TestHyStartDelay_BelowThresh(t *testing.T) {
iss := seqnum.Value(0)
snd := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndUna: iss + 1,
SndNxt: iss + 1,
Ssthresh: InitialSsthresh,
Expand Down
21 changes: 9 additions & 12 deletions pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ func (*Stats) IsEndpointStats() {}
// +stateify savable
type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
TCPSndBufState
}

// CloneState clones sq into other. It is not thread safe
func (sq *sndQueueInfo) CloneState(other *stack.TCPSndBufState) {
func (sq *sndQueueInfo) CloneState(other *TCPSndBufState) {
other.SndBufSize = sq.SndBufSize
other.SndBufUsed = sq.SndBufUsed
other.SndClosed = sq.SndClosed
Expand Down Expand Up @@ -340,7 +340,7 @@ func (sq *sndQueueInfo) CloneState(other *stack.TCPSndBufState) {
//
// +stateify savable
type Endpoint struct {
stack.TCPEndpointStateInner
TCPEndpointStateInner
stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler

Expand Down Expand Up @@ -377,7 +377,7 @@ type Endpoint struct {
rcvQueueMu sync.Mutex `state:"nosave"`

// +checklocks:rcvQueueMu
stack.TCPRcvBufState
TCPRcvBufState

// rcvMemUsed tracks the total amount of memory in use by received segments
// held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to
Expand Down Expand Up @@ -535,7 +535,7 @@ type Endpoint struct {

// probe if not nil is invoked on every received segment. It is passed
// a copy of the current state of the endpoint.
probe stack.TCPProbeFunc `state:"nosave"`
probe TCPProbeFunc `state:"nosave"`

// The following are only used to assist the restore run to re-connect.
connectingAddress tcpip.Address
Expand Down Expand Up @@ -843,7 +843,7 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
TransProto: header.TCPProtocolNumber,
},
sndQueueInfo: sndQueueInfo{
TCPSndBufState: stack.TCPSndBufState{
TCPSndBufState: TCPSndBufState{
SndMTU: math.MaxInt32,
},
},
Expand Down Expand Up @@ -904,10 +904,7 @@ func newEndpoint(s *stack.Stack, protocol *protocol, netProto tcpip.NetworkProto
e.maxSynRetries = uint8(synRetries)
}

if p := s.GetTCPProbe(); p != nil {
e.probe = p
}

e.probe = protocol.probe
e.segmentQueue.ep = e

// TODO(https://gvisor.dev/issues/7493): Defer creating the timer until TCP connection becomes
Expand Down Expand Up @@ -3137,9 +3134,9 @@ func (e *Endpoint) maxOptionSize() (size int) {
// used before invoking the probe.
//
// +checklocks:e.mu
func (e *Endpoint) completeStateLocked(s *stack.TCPEndpointState) {
func (e *Endpoint) completeStateLocked(s *TCPEndpointState) {
s.TCPEndpointStateInner = e.TCPEndpointStateInner
s.ID = stack.TCPEndpointID(e.TransportEndpointInfo.ID)
s.ID = TCPEndpointID(e.TransportEndpointInfo.ID)
s.SegTime = e.stack.Clock().NowMonotonic()
s.Receiver = e.rcv.TCPReceiverState
s.Sender = e.snd.TCPSenderState
Expand Down
25 changes: 22 additions & 3 deletions pkg/tcpip/transport/tcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ type protocol struct {
synRetries uint8
dispatcher dispatcher

// probe, if not nil, will be invoked any time an endpoint receives a
// TCP segment.
//
// This is immutable after creation.
probe TCPProbeFunc `state:"nosave"`

// The following secrets are initialized once and stay unchanged after.
seqnumSecret [16]byte
tsOffsetSecret [16]byte
Expand Down Expand Up @@ -520,18 +526,30 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {

// NewProtocol returns a TCP transport protocol with Reno congestion control.
func NewProtocol(s *stack.Stack) stack.TransportProtocol {
return newProtocol(s, ccReno)
return newProtocol(s, ccReno, nil)
}

// NewProtocolProbe returns a TCP transport protocol with Reno congestion
// control and the given probe.
//
// The probe will be invoked on every segment received by TCP endpoints. The
// probe function is passed a copy of the TCP endpoint state before and after
// processing of the segment.
func NewProtocolProbe(probe TCPProbeFunc) func(*stack.Stack) stack.TransportProtocol {
return func(s *stack.Stack) stack.TransportProtocol {
return newProtocol(s, ccReno, probe)
}
}

// NewProtocolCUBIC returns a TCP transport protocol with CUBIC congestion
// control.
//
// TODO(b/345835636): Remove this and make CUBIC the default across the board.
func NewProtocolCUBIC(s *stack.Stack) stack.TransportProtocol {
return newProtocol(s, ccCubic)
return newProtocol(s, ccCubic, nil)
}

func newProtocol(s *stack.Stack, cc string) stack.TransportProtocol {
func newProtocol(s *stack.Stack, cc string, probe TCPProbeFunc) stack.TransportProtocol {
rng := s.SecureRNG()
var seqnumSecret [16]byte
var tsOffsetSecret [16]byte
Expand Down Expand Up @@ -567,6 +585,7 @@ func newProtocol(s *stack.Stack, cc string) stack.TransportProtocol {
recovery: tcpip.TCPRACKLossDetection,
seqnumSecret: seqnumSecret,
tsOffsetSecret: tsOffsetSecret,
probe: probe,
}
p.dispatcher.init(s.InsecureRNG(), runtime.GOMAXPROCS(0))
return &p
Expand Down
3 changes: 1 addition & 2 deletions pkg/tcpip/transport/tcp/rack.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

const (
Expand Down Expand Up @@ -47,7 +46,7 @@ const (
//
// +stateify savable
type rackControl struct {
stack.TCPRACKState
TCPRACKState

// exitedRecovery indicates if the connection is exiting loss recovery.
// This flag is set if the sender is leaving the recovery after
Expand Down
5 changes: 2 additions & 3 deletions pkg/tcpip/transport/tcp/rcv.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// receiver holds the state necessary to receive TCP segments and turn them
// into a stream of bytes.
//
// +stateify savable
type receiver struct {
stack.TCPReceiverState
TCPReceiverState
ep *Endpoint

// rcvWnd is the non-scaled receive window last advertised to the peer.
Expand All @@ -55,7 +54,7 @@ type receiver struct {
func newReceiver(ep *Endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
TCPReceiverState: stack.TCPReceiverState{
TCPReceiverState: TCPReceiverState{
RcvNxt: irs + 1,
RcvAcc: irs.Add(rcvWnd + 1),
RcvWndScale: rcvWndScale,
Expand Down
8 changes: 4 additions & 4 deletions pkg/tcpip/transport/tcp/snd.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type lossRecovery interface {
//
// +stateify savable
type sender struct {
stack.TCPSenderState
TCPSenderState
ep *Endpoint

// lr is the loss recovery algorithm used by the sender.
Expand Down Expand Up @@ -187,7 +187,7 @@ type sender struct {
type rtt struct {
sync.Mutex `state:"nosave"`

stack.TCPRTTState
TCPRTTState
}

// +checklocks:ep.mu
Expand All @@ -199,15 +199,15 @@ func newSender(ep *Endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint

s := &sender{
ep: ep,
TCPSenderState: stack.TCPSenderState{
TCPSenderState: TCPSenderState{
SndWnd: sndWnd,
SndUna: iss + 1,
SndNxt: iss + 1,
RTTMeasureSeqNum: iss + 1,
LastSendTime: ep.stack.Clock().NowMonotonic(),
MaxPayloadSize: maxPayloadSize,
MaxSentAck: irs + 1,
FastRecovery: stack.TCPFastRecoveryState{
FastRecovery: TCPFastRecoveryState{
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
Last: iss,
HighRxt: iss,
Expand Down
Loading

0 comments on commit 86cd4d3

Please sign in to comment.