Skip to content

Commit

Permalink
[kube] add server_id to targets when monitoring exec/portforward conn…
Browse files Browse the repository at this point in the history
…ections (#47829)

This PR adds the target server_id (kubernetes service) when proxy establishes a connection to support kubectl exec and portforward. This allows proxies to terminate early the connection without relying on the upstream to terminate it.
  • Loading branch information
tigrato authored Oct 29, 2024
1 parent 1807dfd commit 9d4b20c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
22 changes: 17 additions & 5 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,7 @@ func (s *clusterSession) close() {
}
}

func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error) {
func (s *clusterSession) monitorConn(conn net.Conn, err error, hostID string) (net.Conn, error) {
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2321,10 +2321,18 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
s.connMonitorCancel(err)
return nil, trace.Wrap(err)
}

lockTargets := s.LockTargets()
// when the target is not a kubernetes_service instance, we don't need to lock it.
// the target could be a remote cluster or a local Kubernetes API server. In both cases,
// hostID is empty.
if hostID != "" {
lockTargets = append(lockTargets, types.LockTarget{
ServerID: hostID,
})
}
err = srv.StartMonitor(srv.MonitorConfig{
LockWatcher: s.parent.cfg.LockWatcher,
LockTargets: s.LockTargets(),
LockTargets: lockTargets,
DisconnectExpiredCert: s.disconnectExpiredCert,
ClientIdleTimeout: s.clientIdleTimeout,
Clock: s.parent.cfg.Clock,
Expand Down Expand Up @@ -2356,12 +2364,16 @@ func (s *clusterSession) getServerMetadata() apievents.ServerMetadata {
}

func (s *clusterSession) Dial(network, addr string) (net.Conn, error) {
return s.monitorConn(s.dial(s.requestContext, network, addr))
var hostID string
conn, err := s.dial(s.requestContext, network, addr, withHostIDCollection(&hostID))
return s.monitorConn(conn, err, hostID)
}

func (s *clusterSession) DialWithContext(opts ...contextDialerOption) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.monitorConn(s.dial(ctx, network, addr, opts...))
var hostID string
conn, err := s.dial(ctx, network, addr, append(opts, withHostIDCollection(&hostID))...)
return s.monitorConn(conn, err, hostID)
}
}

Expand Down
1 change: 0 additions & 1 deletion lib/kube/proxy/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
if err != nil {
return nil, err
}

if err := req.Write(conn); err != nil {
conn.Close()
return nil, err
Expand Down
23 changes: 20 additions & 3 deletions lib/kube/proxy/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi
ProxyIDs: s.GetProxyIDs(),
})
if err == nil {
opt.collect(s.GetHostID())
return conn, nil
}
errs = append(errs, trace.Wrap(err))
Expand Down Expand Up @@ -423,13 +424,21 @@ func (f *Forwarder) getContextDialerFunc(s *clusterSession, opts ...contextDiale
// contextDialerOptions is a set of options that can be used to filter
// the hosts that the dialer connects to.
type contextDialerOptions struct {
hostID string
hostIDFilter string
collectHostID *string
}

// matches returns true if the host matches the hostID of the dialer options or
// if the dialer hostID is empty.
func (c *contextDialerOptions) matches(hostID string) bool {
return c.hostID == "" || c.hostID == hostID
return c.hostIDFilter == "" || c.hostIDFilter == hostID
}

// collect sets the hostID that the dialer connected to if collectHostID is not nil.
func (c *contextDialerOptions) collect(hostID string) {
if c.collectHostID != nil {
*c.collectHostID = hostID
}
}

// contextDialerOption is a functional option for the contextDialerOptions.
Expand All @@ -442,6 +451,14 @@ type contextDialerOption func(*contextDialerOptions)
// error.
func withTargetHostID(hostID string) contextDialerOption {
return func(o *contextDialerOptions) {
o.hostID = hostID
o.hostIDFilter = hostID
}
}

// withHostIDCollection is a functional option that sets the hostID of the dialer
// to the provided pointer.
func withHostIDCollection(hostID *string) contextDialerOption {
return func(o *contextDialerOptions) {
o.collectHostID = hostID
}
}

0 comments on commit 9d4b20c

Please sign in to comment.