Skip to content

Commit

Permalink
adding additional audit log context around SSH port forwarding (#51327)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktate authored Jan 22, 2025
1 parent 5b3cfb7 commit f839df8
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 195 deletions.
11 changes: 7 additions & 4 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,13 @@ const (
X11ForwardErr = "error"

// Port forwarding event
PortForwardEvent = "port"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"
PortForwardEvent = "port"
PortForwardLocalEvent = "port.local"
PortForwardRemoteEvent = "port.remote"
PortForwardRemoteConnEvent = "port.remote_conn"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"

// AuthAttemptEvent is authentication attempt that either
// succeeded or failed based on event status
Expand Down
6 changes: 6 additions & 0 deletions lib/events/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) {
e = &events.X11Forward{}
case PortForwardEvent:
e = &events.PortForward{}
case PortForwardLocalEvent:
e = &events.PortForward{}
case PortForwardRemoteEvent:
e = &events.PortForward{}
case PortForwardRemoteConnEvent:
e = &events.PortForward{}
case AuthAttemptEvent:
e = &events.AuthAttempt{}
case SCPEvent:
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1395,19 +1395,19 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent() apievents.PortForward {
func (c *ServerContext) GetPortForwardEvent(evType, code, addr string) apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
return apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
Type: evType,
Code: code,
},
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: c.DstAddr,
Addr: addr,
Status: apievents.Status{
Success: true,
},
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.log.WithError(err).Error("Failed to emit audit event.")
}
Expand Down Expand Up @@ -1096,7 +1096,7 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardEvent, events.PortForwardFailureCode, scx.DstAddr)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
scx.WithError(err).Warn("Failed to emit port forward event.")
}
Expand Down
114 changes: 84 additions & 30 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1719,27 +1719,17 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

startEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
s.Logger.Warnf("Connection problem in direct-tcpip channel: %v %T.", trace.DebugReport(err), err)
errEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &errEvent)
slog.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: scx.ServerConn.LocalAddr().String(),
RemoteAddr: scx.ServerConn.RemoteAddr().String(),
},
Addr: scx.DstAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.Logger.WithError(err).Warn("Failed to emit port forward event.")
}
stopEvent := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode, scx.DstAddr)
s.emitAuditEventWithLog(ctx, &stopEvent)
}

// handleSessionRequests handles out of band session requests once the session
Expand Down Expand Up @@ -2073,9 +2063,7 @@ func (s *Server) handleX11Forward(ch ssh.Channel, req *ssh.Request, ctx *srv.Ser
s.replyError(ch, req, err)
err = nil
}
if err := s.EmitAuditEvent(s.ctx, event); err != nil {
s.Logger.WithError(err).Warn("Failed to emit x11-forward event.")
}
s.emitAuditEventWithLog(s.ctx, event)
}()

// check if X11 forwarding is disabled, or if xauth can't be handled.
Expand Down Expand Up @@ -2352,6 +2340,7 @@ func (s *Server) createForwardingContext(ctx context.Context, ccx *sshutils.Conn
if err != nil {
return nil, nil, trace.Wrap(err)
}

listenAddr := sshutils.JoinHostPort(req.Addr, req.Port)
scx.IsTestStub = s.isTestStub
scx.ExecType = teleport.TCPIPForwardRequest
Expand Down Expand Up @@ -2390,13 +2379,72 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.Logger.WithError(err).Warn("Failed to emit audit event.")
}
if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil {
return trace.Wrap(err)
}
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode, scx.SrcAddr)
s.emitAuditEventWithLog(ctx, &event)

// spawn remote forwarding handler to multiplex connections to the forwarded port
go func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode, scx.SrcAddr)
defer s.emitAuditEventWithLog(ctx, &stopEvent)

for {
conn, err := listener.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
slog.WarnContext(ctx, "failed to accept connection", "error", err)
}
return
}
logger := slog.With(
"src_addr", scx.SrcAddr,
"remote_addr", conn.RemoteAddr().String(),
)

dstHost, dstPort, err := sshutils.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to parse addr", "error", err)
return
}

req := sshutils.ForwardedTCPIPRequest{
Addr: srcHost,
Port: listenPort,
OrigAddr: dstHost,
OrigPort: dstPort,
}
if err := req.CheckAndSetDefaults(); err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to create forwarded tcpip request", "error", err)
return
}
reqBytes := ssh.Marshal(req)

ch, rch, err := scx.ConnectionContext.ServerConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes)
if err != nil {
conn.Close()
logger.WarnContext(ctx, "failed to open channel", "error", err)
continue
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go func() {
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode, scx.SrcAddr)
startEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &startEvent)

if err := utils.ProxyConn(ctx, conn, ch); err != nil {
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode, scx.SrcAddr)
errEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &errEvent)
}

stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode, scx.SrcAddr)
stopEvent.RemoteAddr = conn.RemoteAddr().String()
s.emitAuditEventWithLog(ctx, &stopEvent)
}()
}
}()

// Report addr back to the client.
if r.WantReply {
Expand Down Expand Up @@ -2428,14 +2476,14 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut
return trace.Wrap(err)
}
defer scx.Close()

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if !ok {
return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr)
}
if err := r.Reply(true, nil); err != nil {
s.Logger.Warnf("Failed to reply to %q request: %v", r.Type, err)
}

return trace.Wrap(listener.Close())
}

Expand Down Expand Up @@ -2478,7 +2526,7 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ctx *srv.ServerContext)
case r.Name == teleport.SFTPSubsystem:
err := ctx.CheckSFTPAllowed(s.reg)
if err != nil {
s.EmitAuditEvent(context.Background(), &apievents.SFTP{
s.emitAuditEventWithLog(context.Background(), &apievents.SFTP{
Metadata: apievents.Metadata{
Code: events.SFTPDisallowedCode,
Type: events.SFTPEvent,
Expand Down Expand Up @@ -2525,3 +2573,9 @@ func (s *Server) handlePuTTYWinadj(ch ssh.Channel, req *ssh.Request) error {
req.WantReply = false
return nil
}

func (s *Server) emitAuditEventWithLog(ctx context.Context, event apievents.AuditEvent) {
if err := s.EmitAuditEvent(ctx, event); err != nil {
slog.WarnContext(ctx, "Failed to emit event", "type", event.GetType(), "code", event.GetCode())
}
}
10 changes: 0 additions & 10 deletions lib/sshutils/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ package sshutils
import (
"errors"
"io"

"golang.org/x/crypto/ssh"
)

var (
Expand Down Expand Up @@ -59,11 +57,3 @@ func (mc *mockChannel) SendRequest(name string, wantReply bool, payload []byte)
func (mc *mockChannel) Stderr() io.ReadWriter {
return fakeReaderWriter{}
}

type mockSSHConn struct {
mockChan *mockChannel
}

func (mc *mockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
return mc.mockChan, make(<-chan *ssh.Request), nil
}
68 changes: 0 additions & 68 deletions lib/sshutils/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,9 @@
package sshutils

import (
"context"
"io"
"net"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/utils"
)

// DirectTCPIPReq represents the payload of an SSH "direct-tcpip" or
Expand Down Expand Up @@ -72,64 +65,3 @@ func ParseTCPIPForwardReq(data []byte) (*TCPIPForwardReq, error) {
}
return &r, nil
}

type channelOpener interface {
OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error)
}

// StartRemoteListener listens on the given listener and forwards any accepted
// connections over a new "forwarded-tcpip" channel.
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener) error {
srcHost, srcPort, err := SplitHostPort(srcAddr)
if err != nil {
return trace.Wrap(err)
}

go func() {
for {
conn, err := listener.Accept()
if err != nil {
if !utils.IsOKNetworkError(err) {
log.WithError(err).Warn("failed to accept connection")
}
return
}
logger := log.WithFields(log.Fields{
"srcAddr": srcAddr,
"remoteAddr": conn.RemoteAddr().String(),
})

dstHost, dstPort, err := SplitHostPort(conn.RemoteAddr().String())
if err != nil {
conn.Close()
logger.WithError(err).Warn("failed to parse addr")
return
}

req := ForwardedTCPIPRequest{
Addr: srcHost,
Port: srcPort,
OrigAddr: dstHost,
OrigPort: dstPort,
}
if err := req.CheckAndSetDefaults(); err != nil {
conn.Close()
logger.WithError(err).Warn("failed to create forwarded tcpip request")
return
}
reqBytes := ssh.Marshal(req)

ch, rch, err := sshConn.OpenChannel(teleport.ChanForwardedTCPIP, reqBytes)
if err != nil {
conn.Close()
logger.WithError(err).Warn("failed to open channel")
continue
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go utils.ProxyConn(ctx, conn, ch)
}
}()

return nil
}
Loading

0 comments on commit f839df8

Please sign in to comment.