From 25497b98baf4b0f0ea30f093f2d345bc4e975549 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 24 May 2024 16:31:24 -0400 Subject: [PATCH] Reduce tsh binary size (#41743) (#41976) * Remove lib/web dependency from tsh Refactors the terminal stream code in lib/web into a subpackage so that tsh no longer depends on the entire lib/web package. * Remove lib/auth dependency from tsh Refactors invalid credentials errors and helpers into lib/auth/authclient so that tsh no longer imports lib/auth directly. * fix find-replace typos --- buf.yaml | 2 +- build.assets/genproto.sh | 4 +- integration/assist/command_test.go | 7 +- integration/helpers/instance.go | 7 +- lib/auth/authclient/clt.go | 23 +- lib/auth/methods.go | 21 +- lib/benchmark/web.go | 19 +- lib/web/apiserver.go | 3 +- lib/web/apiserver_test.go | 65 +- lib/web/assistant.go | 2 +- lib/web/command.go | 25 +- lib/web/command_test.go | 9 +- lib/web/command_utils.go | 29 +- lib/web/desktop.go | 4 +- lib/web/fuzz_test.go | 4 +- lib/web/mfa_codec.go | 28 +- lib/web/terminal.go | 530 +--------------- lib/web/{ => terminal}/envelope.pb.go | 37 +- lib/web/terminal/terminal.go | 596 ++++++++++++++++++ lib/web/terminal_test.go | 7 +- lib/web/ws_io.go | 10 +- proto/buf.yaml | 2 +- .../lib/web/{ => terminal}/envelope.proto | 2 +- tool/tsh/common/tsh.go | 3 +- 24 files changed, 778 insertions(+), 661 deletions(-) rename lib/web/{ => terminal}/envelope.pb.go (88%) create mode 100644 lib/web/terminal/terminal.go rename proto/teleport/lib/web/{ => terminal}/envelope.proto (93%) diff --git a/buf.yaml b/buf.yaml index 9655c01233119..a60eb8d931282 100644 --- a/buf.yaml +++ b/buf.yaml @@ -32,7 +32,7 @@ lint: - api/proto/teleport/legacy/types/types.proto - api/proto/teleport/legacy/types/wrappers/wrappers.proto - proto/teleport/lib/multiplexer/test/ping.proto - - proto/teleport/lib/web/envelope.proto + - proto/teleport/lib/web/terminal/envelope.proto ignore_only: COMMENT_MESSAGE: - proto/prehog diff --git a/build.assets/genproto.sh b/build.assets/genproto.sh index c16fb5a5962bd..b9b028876c64b 100755 --- a/build.assets/genproto.sh +++ b/build.assets/genproto.sh @@ -46,7 +46,7 @@ main() { --path=api/proto/teleport/legacy/ \ --path=api/proto/teleport/attestation/ \ --path=api/proto/teleport/usageevents/ \ - --path=proto/teleport/lib/web/envelope.proto \ + --path=proto/teleport/lib/web/terminal/envelope.proto \ --exclude-path=api/proto/teleport/legacy/client/proto/event.proto cp -r gogogen/github.com/gravitational/teleport/. . # error out if there's anything outside of github.com/gravitational/teleport @@ -58,7 +58,7 @@ main() { --exclude-path=api/proto/teleport/legacy/ \ --exclude-path=api/proto/teleport/attestation/ \ --exclude-path=api/proto/teleport/usageevents/ \ - --exclude-path=proto/teleport/lib/web/envelope.proto \ + --exclude-path=proto/teleport/lib/web/terminal/envelope.proto \ --exclude-path=proto/prehog/ # Generate event.proto separately because we only want to run it on this diff --git a/integration/assist/command_test.go b/integration/assist/command_test.go index 61dc07027b771..ed855c6416165 100644 --- a/integration/assist/command_test.go +++ b/integration/assist/command_test.go @@ -66,6 +66,7 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" + "github.com/gravitational/teleport/lib/web/terminal" ) const ( @@ -452,12 +453,12 @@ type executionWebsocketReader struct { *websocket.Conn } -func (r executionWebsocketReader) Read() (web.Envelope, error) { +func (r executionWebsocketReader) Read() (terminal.Envelope, error) { _, data, err := r.ReadMessage() if err != nil { - return web.Envelope{}, trace.Wrap(err) + return terminal.Envelope{}, trace.Wrap(err) } - var envelope web.Envelope + var envelope terminal.Envelope return envelope, trace.Wrap(proto.Unmarshal(data, &envelope)) } diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index b6f19db69b79b..455557516ef8a 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -67,6 +67,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" websession "github.com/gravitational/teleport/lib/web/session" + "github.com/gravitational/teleport/lib/web/terminal" ) const ( @@ -1532,7 +1533,7 @@ func CreateWebSession(proxyHost, user, password string) (*web.CreateSessionRespo // SSH establishes an SSH connection via the web api in the same manner that // the web UI does. The returned [web.TerminalStream] should be used as stdin/stdout // for the session. -func (w *WebClient) SSH(termReq web.TerminalRequest) (*web.TerminalStream, error) { +func (w *WebClient) SSH(termReq web.TerminalRequest) (*terminal.Stream, error) { u := url.URL{ Host: w.i.Web, Scheme: client.WSS, @@ -1574,7 +1575,7 @@ func (w *WebClient) SSH(termReq web.TerminalRequest) (*web.TerminalStream, error return nil, trace.BadParameter("unexpected websocket message; got %d want %d", ty, websocket.BinaryMessage) } - var env web.Envelope + var env terminal.Envelope err = proto.Unmarshal(raw, &env) if err != nil { return nil, trace.Wrap(err) @@ -1590,7 +1591,7 @@ func (w *WebClient) SSH(termReq web.TerminalRequest) (*web.TerminalStream, error return nil, trace.Wrap(err) } - stream := web.NewTerminalStream(context.Background(), ws, utils.NewLoggerForTests()) + stream := terminal.NewStream(context.Background(), terminal.StreamConfig{WS: ws}) return stream, nil } diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 6f9f4638e385c..149139ed7dbcf 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -20,6 +20,7 @@ package authclient import ( "context" + "errors" "fmt" "net" "net/url" @@ -68,10 +69,24 @@ const ( MissingNamespaceError = "missing required parameter: namespace" ) -// ErrNoMFADevices is returned when an MFA ceremony is performed without possible devices to -// complete the challenge with. -var ErrNoMFADevices = &trace.AccessDeniedError{ - Message: "MFA is required to access this resource but user has no MFA devices; use 'tsh mfa add' to register MFA devices", +var ( + // ErrNoMFADevices is returned when an MFA ceremony is performed without possible devices to + // complete the challenge with. + ErrNoMFADevices = &trace.AccessDeniedError{ + Message: "MFA is required to access this resource but user has no MFA devices; use 'tsh mfa add' to register MFA devices", + } + // InvalidUserPassError is the error for when either the provided username or + // password is incorrect. + InvalidUserPassError = &trace.AccessDeniedError{Message: "invalid username or password"} + // InvalidUserPass2FError is the error for when either the provided username, + // password, or second factor is incorrect. + InvalidUserPass2FError = &trace.AccessDeniedError{Message: "invalid username, password or second factor"} +) + +// IsInvalidLocalCredentialError checks if an error resulted from an incorrect username, +// password, or second factor. +func IsInvalidLocalCredentialError(err error) bool { + return errors.Is(err, InvalidUserPassError) || errors.Is(err, InvalidUserPass2FError) } // HostFQDN consists of host UUID and cluster name joined via '.'. diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 54873e5ca2a11..5389876df9dfa 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -191,17 +191,10 @@ func (a *Server) emitAuthAuditEvent(ctx context.Context, props authAuditProps) e var ( // authenticateHeadlessError is the generic error returned for failed headless // authentication attempts. - authenticateHeadlessError = trace.AccessDenied("headless authentication failed") + authenticateHeadlessError = &trace.AccessDeniedError{Message: "headless authentication failed"} // authenticateWebauthnError is the generic error returned for failed WebAuthn // authentication attempts. - authenticateWebauthnError = trace.AccessDenied("invalid Webauthn response") - // invalidUserPassError is the error for when either the provided username or - // password is incorrect. - invalidUserPassError = trace.AccessDenied("invalid username or password") - // invalidUserpass2FError is the error for when either the provided username, - // password, or second factor is incorrect. - invalidUserPass2FError = trace.AccessDenied("invalid username, password or second factor") - + authenticateWebauthnError = &trace.AccessDeniedError{Message: "invalid Webauthn response"} // errSSOUserLocalAuth is issued for SSO users attempting local authentication // or related actions (like trying to set a password) // Kept purposefully vague, as such actions don't happen during normal @@ -209,12 +202,6 @@ var ( errSSOUserLocalAuth = &trace.AccessDeniedError{Message: "invalid credentials"} ) -// IsInvalidLocalCredentialError checks if an error resulted from an incorrect username, -// password, or second factor. -func IsInvalidLocalCredentialError(err error) bool { - return errors.Is(err, invalidUserPassError) || errors.Is(err, invalidUserPass2FError) -} - type verifyMFADeviceLocksParams struct { // Checker used to verify locks. // Optional, created via a [UserState] fetch if nil. @@ -352,7 +339,7 @@ func (a *Server) authenticateUserInternal(ctx context.Context, req authclient.Au } return res.mfaDev, nil } - authErr = invalidUserPass2FError + authErr = authclient.InvalidUserPass2FError } if authenticateFn != nil { err := a.WithUserLock(user, func() error { @@ -420,7 +407,7 @@ func (a *Server) authenticateUserInternal(ctx context.Context, req authclient.Au // provide obscure message on purpose, while logging the real // error server side log.Debugf("User %v failed to authenticate: %v.", user, err) - return nil, "", trace.Wrap(invalidUserPassError) + return nil, "", trace.Wrap(authclient.InvalidUserPassError) } return nil, user, nil } diff --git a/lib/benchmark/web.go b/lib/benchmark/web.go index 7a52db72e7b6e..d783010d2afef 100644 --- a/lib/benchmark/web.go +++ b/lib/benchmark/web.go @@ -37,7 +37,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/web" + "github.com/gravitational/teleport/lib/web/terminal" ) // WebSSHBenchmark is a benchmark suite that connects to the configured @@ -172,9 +172,22 @@ func getServers(ctx context.Context, tc *client.TeleportClient) ([]types.Server, return resources, nil } +// TerminalRequest describes a request to create a web-based terminal +// to a remote SSH server. +type TerminalRequest struct { + // Server describes a server to connect to (serverId|hostname[:port]). + Server string `json:"server_id"` + + // Login is Linux username to connect as. + Login string `json:"login"` + + // Term is the initial PTY size. + Term session.TerminalParams `json:"term"` +} + // connectToHost opens an SSH session to the target host via the Proxy web api. func connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *webSession, host string) (io.ReadWriteCloser, error) { - req := web.TerminalRequest{ + req := TerminalRequest{ Server: host, Login: tc.HostLogin, Term: session.TerminalParams{ @@ -220,7 +233,7 @@ func connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *w return nil, trace.BadParameter("unexpected websocket message received %d", ty) } - stream := web.NewTerminalStream(ctx, ws, utils.NewLogger()) + stream := terminal.NewStream(ctx, terminal.StreamConfig{WS: ws}) return stream, trace.Wrap(err) } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 9b215394f6265..a285ea8e1b8f9 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -92,6 +92,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" websession "github.com/gravitational/teleport/lib/web/session" + "github.com/gravitational/teleport/lib/web/terminal" "github.com/gravitational/teleport/lib/web/ui" ) @@ -3807,7 +3808,7 @@ func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { if err == nil { return } - errEnvelope := Envelope{ + errEnvelope := terminal.Envelope{ Type: defaults.WebsocketError, Payload: trace.UserMessage(err), } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 354420a894003..bf8b1f6e3402c 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -130,6 +130,7 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" websession "github.com/gravitational/teleport/lib/web/session" + "github.com/gravitational/teleport/lib/web/terminal" "github.com/gravitational/teleport/lib/web/ui" ) @@ -1438,7 +1439,7 @@ func TestResolveServerHostPort(t *testing.T) { } } -func isFileTransferRequest(e *Envelope) bool { +func isFileTransferRequest(e *terminal.Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } @@ -1449,7 +1450,7 @@ func isFileTransferRequest(e *Envelope) bool { return ef.GetType() == string(srv.FileTransferUpdate) } -func isFileTransferDecision(e *Envelope) bool { +func isFileTransferDecision(e *terminal.Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } @@ -1460,7 +1461,7 @@ func isFileTransferDecision(e *Envelope) bool { return ef.GetType() == string(srv.FileTransferApproved) } -func getRequestId(e *Envelope) (string, error) { +func getRequestId(e *terminal.Envelope) (string, error) { var ef events.EventFields if err := json.Unmarshal([]byte(e.GetPayload()), &ef); err != nil { return "", err @@ -1473,7 +1474,7 @@ func TestFileTransferEvents(t *testing.T) { s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) errs := make(chan error, 2) - readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *Envelope) { + readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *terminal.Envelope) { for { select { case <-ctx.Done(): @@ -1490,7 +1491,7 @@ func TestFileTransferEvents(t *testing.T) { errs <- trace.BadParameter("expected binary message, got %v", typ) return } - var envelope Envelope + var envelope terminal.Envelope if err := proto.Unmarshal(b, &envelope); err != nil { errs <- trace.Wrap(err) return @@ -1507,7 +1508,7 @@ func TestFileTransferEvents(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - wsMessages := make(chan *Envelope) + wsMessages := make(chan *terminal.Envelope) go readLoop(ctx, ws, wsMessages) // Create file transfer event @@ -1517,7 +1518,7 @@ func TestFileTransferEvents(t *testing.T) { }) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketFileTransferRequest, Payload: string(data), @@ -1543,7 +1544,7 @@ func TestFileTransferEvents(t *testing.T) { "approved": true, }) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketFileTransferDecision, Payload: string(data), @@ -1702,7 +1703,7 @@ func TestResizeTerminal(t *testing.T) { sid := session.NewID() errs := make(chan error, 2) - readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *Envelope) { + readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *terminal.Envelope) { for { select { case <-ctx.Done(): @@ -1719,7 +1720,7 @@ func TestResizeTerminal(t *testing.T) { errs <- trace.BadParameter("expected binary message, got %v", typ) return } - var envelope Envelope + var envelope terminal.Envelope if err := proto.Unmarshal(b, &envelope); err != nil { errs <- trace.Wrap(err) return @@ -1750,8 +1751,8 @@ func TestResizeTerminal(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - ws1Messages := make(chan *Envelope) - ws2Messages := make(chan *Envelope) + ws1Messages := make(chan *terminal.Envelope) + ws2Messages := make(chan *terminal.Envelope) go readLoop(ctx, ws1, ws1Messages) go readLoop(ctx, ws2, ws2Messages) @@ -1804,7 +1805,7 @@ t1ready: events.TerminalSize: params.Serialize(), }) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketResize, Payload: string(data), @@ -1830,7 +1831,7 @@ t1ready: } } -func isResizeEventEnvelope(e *Envelope) bool { +func isResizeEventEnvelope(e *terminal.Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } @@ -2017,7 +2018,7 @@ func TestTerminalRouting(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { tt.wsCloseAssertion(t, ws.Close()) }) - stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) + stream := terminal.NewStream(s.ctx, terminal.StreamConfig{WS: ws}) // here we intentionally run a command where the output we're looking // for is not present in the command itself @@ -2141,7 +2142,7 @@ func TestTerminalRequireSessionMFA(t *testing.T) { webauthnResBytes, err := json.Marshal(wantypes.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), @@ -2168,15 +2169,15 @@ func TestTerminalRequireSessionMFA(t *testing.T) { ty, raw, err := ws.ReadMessage() require.Nil(t, err) require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope + var env terminal.Envelope require.Nil(t, proto.Unmarshal(raw, &env)) chals := &client.MFAAuthenticateChallenge{} require.Nil(t, json.Unmarshal([]byte(env.Payload), &chals)) // Send response over ws. - stream := NewTerminalStream(ctx, ws, utils.NewLoggerForTests()) - err = stream.ws.WriteMessage(websocket.BinaryMessage, tc.getChallengeResponseBytes(chals, dev)) + stream := terminal.NewStream(ctx, terminal.StreamConfig{WS: ws}) + err = stream.WriteMessage(websocket.BinaryMessage, tc.getChallengeResponseBytes(chals, dev)) require.Nil(t, err) // Test we can write. @@ -2366,7 +2367,7 @@ func TestWebAgentForward(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) - stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) + stream := terminal.NewStream(s.ctx, terminal.StreamConfig{WS: ws}) _, err = io.WriteString(stream, "echo $SSH_AUTH_SOCK\r\n") require.NoError(t, err) @@ -2478,7 +2479,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) - stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) + stream := terminal.NewStream(s.ctx, terminal.StreamConfig{WS: ws}) // to make sure we have a session _, err = io.WriteString(stream, "expr 137 + 39\r\n") @@ -7563,7 +7564,7 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp return nil, nil, trace.Wrap(err) } require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope + var env terminal.Envelope err = proto.Unmarshal(raw, &env) if err != nil { @@ -8325,7 +8326,7 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session ty, raw, err := ws.ReadMessage() require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope + var env terminal.Envelope require.NoError(t, proto.Unmarshal(raw, &env)) var sessResp siteSessionGenerateResponse @@ -8393,7 +8394,7 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack) *websocket. func validateTerminalStream(t *testing.T, ws *websocket.Conn) { t.Helper() - stream := NewTerminalStream(context.Background(), ws, utils.NewLoggerForTests()) + stream := terminal.NewStream(context.Background(), terminal.StreamConfig{WS: ws}) // here we intentionally run a command where the output we're looking // for is not present in the command itself @@ -9527,7 +9528,7 @@ func TestModeratedSession(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, peerWS.Close()) }) - peerStream := NewTerminalStream(ctx, peerWS, utils.NewLoggerForTests()) + peerStream := terminal.NewStream(ctx, terminal.StreamConfig{WS: peerWS}) require.NoError(t, waitForOutput(peerStream, "Teleport > User foo joined the session with participant mode: peer.")) @@ -9536,7 +9537,7 @@ func TestModeratedSession(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, moderatorWS.Close()) }) - moderatorStream := NewTerminalStream(ctx, moderatorWS, utils.NewLoggerForTests()) + moderatorStream := terminal.NewStream(ctx, terminal.StreamConfig{WS: moderatorWS}) require.NoError(t, waitForOutput(peerStream, "Teleport > Connecting to node over SSH")) @@ -9616,7 +9617,7 @@ func TestModeratedSessionWithMFA(t *testing.T) { handleMFAWebauthnChallenge(t, peerWS, peer.device) - peerStream := NewTerminalStream(ctx, peerWS, utils.NewLoggerForTests()) + peerStream := terminal.NewStream(ctx, terminal.StreamConfig{WS: peerWS}) require.NoError(t, waitForOutput(peerStream, "Teleport > User foo joined the session with participant mode: peer.")) @@ -9626,7 +9627,7 @@ func TestModeratedSessionWithMFA(t *testing.T) { handleMFAWebauthnChallenge(t, moderatorWS, moderator.device) - moderatorStream := NewTerminalStream(ctx, moderatorWS, utils.NewLoggerForTests()) + moderatorStream := terminal.NewStream(ctx, terminal.StreamConfig{WS: moderatorWS}) require.NoError(t, waitForOutput(peerStream, "Teleport > Connecting to node over SSH")) @@ -9642,7 +9643,7 @@ func TestModeratedSessionWithMFA(t *testing.T) { s.clock.Advance(30 * time.Second) require.NoError(t, waitForOutput(moderatorStream, "Teleport > Please tap your MFA key")) - challenge, err := moderatorStream.readChallenge(protobufMFACodec{}) + challenge, err := moderatorStream.ReadChallenge(protobufMFACodec{}) require.NoError(t, err) res, err := moderator.device.SolveAuthn(challenge) @@ -9651,7 +9652,7 @@ func TestModeratedSessionWithMFA(t *testing.T) { webauthnResBytes, err := json.Marshal(wantypes.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), @@ -9675,7 +9676,7 @@ func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.Test require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope + var env terminal.Envelope require.NoError(t, proto.Unmarshal(raw, &env)) var challenge client.MFAAuthenticateChallenge @@ -9689,7 +9690,7 @@ func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.Test webauthnResBytes, err := json.Marshal(wantypes.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 94e5b9198678d..f7747b65f67a4 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -498,7 +498,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return nil }) - go startPingLoop(ctx, ws, keepAliveInterval, h.log, nil) + go startWSPingLoop(ctx, ws, keepAliveInterval, h.log, nil) assistClient, err := assist.NewClient(ctx, h.cfg.ProxyClient, h.cfg.ProxySettings, h.cfg.OpenAIConfig) diff --git a/lib/web/command.go b/lib/web/command.go index ccea0c3deb524..65e6892f815f5 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -56,6 +56,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/teleagent" + "github.com/gravitational/teleport/lib/web/terminal" ) // summaryBufferCapacity is the summary buffer size in bytes. The summary buffer @@ -334,7 +335,7 @@ type summaryRequest struct { func (h *Handler) computeAndSendSummary( ctx context.Context, req *summaryRequest, - ws WSConn, + ws terminal.WSConn, ) (*tokens.TokenCount, error) { // Convert the map nodeId->output into a map nodeName->output namedOutput := outputByName(req.hosts, req.output) @@ -392,7 +393,7 @@ func (h *Handler) computeAndSendSummary( if err != nil { return nil, trace.Wrap(err) } - stream := NewWStream(ctx, ws, log, nil) + stream := terminal.NewWStream(ctx, ws, log, nil) _, err = stream.Write(data) return tokenCount, trace.Wrap(err) } @@ -563,10 +564,10 @@ type commandHandler struct { sshBaseHandler // stream is the websocket stream to the client. - stream *WSStream + stream *terminal.WSStream // ws a raw websocket connection to the client. - ws WSConn + ws terminal.WSConn // mfaAuthCache is a function that caches the result of a function that // returns a list of ssh.AuthMethods. It is used to cache the result of @@ -579,8 +580,8 @@ type commandHandler struct { } // sendError sends an error message to the client using the provided websocket. -func (t *sshBaseHandler) sendError(errMsg string, err error, ws WSConn) { - envelope := &Envelope{ +func (t *sshBaseHandler) sendError(errMsg string, err error, ws terminal.WSConn) { + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketError, Payload: fmt.Sprintf("%s: %s", errMsg, err.Error()), @@ -607,7 +608,7 @@ func (t *commandHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { return } - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketSessionMetadata, Payload: string(sessionMetadataResponse), @@ -629,7 +630,7 @@ func (t *commandHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { } func (t *commandHandler) handler(r *http.Request) { - t.stream = NewWStream(r.Context(), t.ws, t.log, nil) + t.stream = terminal.NewWStream(r.Context(), t.ws, t.log, nil) // Create a Teleport client, if not able to, show the reason to the user in // the terminal. @@ -643,7 +644,7 @@ func (t *commandHandler) handler(r *http.Request) { t.log.Debug("Creating websocket stream") // Start sending ping frames through websocket to the client. - go startPingLoop(r.Context(), t.ws, t.keepAliveInterval, t.log, t.Close) + go startWSPingLoop(r.Context(), t.ws, t.keepAliveInterval, t.log, t.Close) // Pump raw terminal in/out and audit events into the websocket. t.streamOutput(r.Context(), tc) @@ -675,7 +676,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl return } - if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil { + if err := t.stream.SendCloseMessage(t.sessionData.ServerID); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") return } @@ -686,7 +687,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl // connectToNodeWithMFA attempts to perform the mfa ceremony and then dial the // host with the retrieved single use certs. // If called multiple times, the mfa ceremony will only be performed once. -func (t *commandHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { +func (t *commandHandler) connectToNodeWithMFA(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { authMethods, err := t.mfaAuthCache(func() ([]ssh.AuthMethod, error) { // perform mfa ceremony and retrieve new certs authMethods, err := t.issueSessionMFACerts(ctx, tc, t.stream) @@ -710,7 +711,7 @@ func (t *commandHandler) Close() error { } // makeClient builds a *client.TeleportClient for the connection. -func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.TeleportClient, error) { +func (t *commandHandler) makeClient(ctx context.Context, ws terminal.WSConn) (*client.TeleportClient, error) { ctx, span := tracing.DefaultProvider().Tracer("command").Start(ctx, "commandHandler/makeClient") defer span.End() diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 2b3c7bc9ca24f..6bdee9e72f652 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -50,6 +50,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/web/terminal" ) const ( @@ -81,7 +82,7 @@ func TestExecuteCommand(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) - stream := NewWStream(context.Background(), ws, utils.NewLoggerForTests(), nil) + stream := terminal.NewWStream(context.Background(), ws, utils.NewLoggerForTests(), nil) require.NoError(t, waitForCommandOutput(stream, "teleport")) } @@ -128,7 +129,7 @@ func TestExecuteCommandHistory(t *testing.T) { ws, _, err := s.makeCommand(t, authPack, conversationID) require.NoError(t, err) - stream := NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) + stream := terminal.NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) // When command executes require.NoError(t, waitForCommandOutput(stream, "teleport")) @@ -301,7 +302,7 @@ func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid return nil, nil, trace.Wrap(err) } require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope + var env terminal.Envelope err = proto.Unmarshal(raw, &env) if err != nil { @@ -402,7 +403,7 @@ func (r *wsReader) Read(p []byte) (int, error) { return 0, trace.Wrap(err) } - var envelope Envelope + var envelope terminal.Envelope if err := proto.Unmarshal(data, &envelope); err != nil { return 0, trace.Errorf("Unable to parse message payload %v", err) } diff --git a/lib/web/command_utils.go b/lib/web/command_utils.go index e9664c1a2f1f2..ec874348b4485 100644 --- a/lib/web/command_utils.go +++ b/lib/web/command_utils.go @@ -21,33 +21,12 @@ package web import ( "encoding/json" "io" - "net" "sync" - "time" "github.com/gravitational/trace" -) - -// WSConn is a gorilla/websocket minimal interface used by our web implementation. -// This interface exists to override the default websocket.Conn implementation, -// currently used by noopCloserWS to prevent WS being closed by wrapping stream. -type WSConn interface { - Close() error - - LocalAddr() net.Addr - RemoteAddr() net.Addr - WriteControl(messageType int, data []byte, deadline time.Time) error - WriteMessage(messageType int, data []byte) error - ReadMessage() (messageType int, p []byte, err error) - SetReadLimit(limit int64) - SetReadDeadline(t time.Time) error - - PongHandler() func(appData string) error - SetPongHandler(h func(appData string) error) - CloseHandler() func(code int, text string) error - SetCloseHandler(h func(code int, text string) error) -} + "github.com/gravitational/teleport/lib/web/terminal" +) const ( EnvelopeTypeStdout = "stdout" @@ -105,7 +84,7 @@ func newPayloadWriter(nodeID, outputName string, stream io.Writer) *payloadWrite // by any underlying code as we want to keep the connection open until the command // is executed on all nodes and a single failure should not close the connection. type noopCloserWS struct { - WSConn + terminal.WSConn } // Close does nothing. @@ -123,7 +102,7 @@ func (ws *noopCloserWS) Close() error { // This would prevent the pong handler from being called. type syncRWWSConn struct { // WSConn the underlying websocket connection. - WSConn + terminal.WSConn // rmtx is a mutex used to serialize reads. rmtx sync.Mutex // wmtx is a mutex used to serialize writes. diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 8d29f2b9562a2..7bd5945bd4912 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -337,7 +337,7 @@ func (h *Handler) performMFACeremony(ctx context.Context, authClient authclient. codec := tdpMFACodec{} // Send the challenge over the socket. - msg, err := codec.encode( + msg, err := codec.Encode( &client.MFAAuthenticateChallenge{ WebauthnChallenge: wantypes.CredentialAssertionFromProto(c.WebauthnChallenge), }, @@ -361,7 +361,7 @@ func (h *Handler) performMFACeremony(ctx context.Context, authClient authclient. return nil, trace.BadParameter("received unexpected web socket message type %d", ty) } - assertion, err := codec.decodeResponse(buf, defaults.WebsocketWebauthnChallenge) + assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/fuzz_test.go b/lib/web/fuzz_test.go index a14ac0df971d5..f66edb44a3c86 100644 --- a/lib/web/fuzz_test.go +++ b/lib/web/fuzz_test.go @@ -73,7 +73,7 @@ func FuzzTdpMFACodecDecodeChallenge(f *testing.F) { f.Fuzz(func(t *testing.T, buf []byte) { require.NotPanics(t, func() { codec := tdpMFACodec{} - _, _ = codec.decodeChallenge(buf, "") + _, _ = codec.DecodeChallenge(buf, "") }) }) } @@ -102,7 +102,7 @@ func FuzzTdpMFACodecDecodeResponse(f *testing.F) { f.Fuzz(func(t *testing.T, buf []byte) { require.NotPanics(t, func() { codec := tdpMFACodec{} - _, _ = codec.decodeResponse(buf, "") + _, _ = codec.DecodeResponse(buf, "") }) }) } diff --git a/lib/web/mfa_codec.go b/lib/web/mfa_codec.go index ae11a8ca6cd71..370f527c36455 100644 --- a/lib/web/mfa_codec.go +++ b/lib/web/mfa_codec.go @@ -29,31 +29,19 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/web/mfajson" + "github.com/gravitational/teleport/lib/web/terminal" ) -// mfaCodec converts MFA challenges/responses between their native types and a format -// suitable for being sent over a network connection. -type mfaCodec interface { - // encode converts an MFA challenge to wire format - encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) - - // decodeChallenge parses an MFA authentication challenge - decodeChallenge(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) - - // decodeResponse parses an MFA authentication response - decodeResponse(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) -} - // protobufMFACodec converts MFA challenges and responses to the protobuf // format used by SSH web sessions type protobufMFACodec struct{} -func (protobufMFACodec) encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) { +func (protobufMFACodec) Encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) { jsonBytes, err := json.Marshal(chal) if err != nil { return nil, trace.Wrap(err) } - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: envelopeType, Payload: string(jsonBytes), @@ -65,11 +53,11 @@ func (protobufMFACodec) encode(chal *client.MFAAuthenticateChallenge, envelopeTy return protoBytes, nil } -func (protobufMFACodec) decodeResponse(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) { +func (protobufMFACodec) DecodeResponse(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) { return mfajson.Decode(bytes, envelopeType) } -func (protobufMFACodec) decodeChallenge(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) { +func (protobufMFACodec) DecodeChallenge(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) { var challenge client.MFAAuthenticateChallenge if err := json.Unmarshal(bytes, &challenge); err != nil { return nil, trace.Wrap(err) @@ -84,7 +72,7 @@ func (protobufMFACodec) decodeChallenge(bytes []byte, envelopeType string) (*aut // Protocol (TDP) messages used by Desktop Access web sessions type tdpMFACodec struct{} -func (tdpMFACodec) encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) { +func (tdpMFACodec) Encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) { switch envelopeType { case defaults.WebsocketWebauthnChallenge: default: @@ -99,7 +87,7 @@ func (tdpMFACodec) encode(chal *client.MFAAuthenticateChallenge, envelopeType st return tdpMsg.Encode() } -func (tdpMFACodec) decodeResponse(buf []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) { +func (tdpMFACodec) DecodeResponse(buf []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) { if len(buf) == 0 { return nil, trace.BadParameter("empty MFA message received") } @@ -113,7 +101,7 @@ func (tdpMFACodec) decodeResponse(buf []byte, envelopeType string) (*authproto.M return msg.MFAAuthenticateResponse, nil } -func (tdpMFACodec) decodeChallenge(buf []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) { +func (tdpMFACodec) DecodeChallenge(buf []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) { if len(buf) == 0 { return nil, trace.BadParameter("empty MFA message received") } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index c826c0563c420..2117a79a639ac 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -20,7 +20,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net" "net/http" @@ -38,8 +37,6 @@ import ( "github.com/sirupsen/logrus" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" - "golang.org/x/text/encoding" - "golang.org/x/text/encoding/unicode" "github.com/gravitational/teleport" authproto "github.com/gravitational/teleport/api/client/proto" @@ -62,6 +59,7 @@ import ( "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/diagnostics/latency" + "github.com/gravitational/teleport/lib/web/terminal" ) // TerminalRequest describes a request to create a web-based terminal @@ -286,7 +284,7 @@ type TerminalHandler struct { // stream manages sending and receiving [Envelope] to the UI // for the duration of the session - stream *TerminalStream + stream *terminal.Stream // tracker is the session tracker of the session being joined. May be nil // if the user is not joining a session. tracker types.SessionTracker @@ -337,7 +335,7 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - envelope := &Envelope{ + envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketSessionMetadata, Payload: string(sessionMetadataResponse), @@ -366,11 +364,7 @@ func (t *TerminalHandler) Close() error { return } - if t.stream.sshSession != nil { - err = trace.NewAggregate(t.stream.sshSession.Close(), t.stream.Close()) - } else { - err = trace.Wrap(t.stream.Close()) - } + err = trace.Wrap(t.stream.Close()) }) return trace.Wrap(err) } @@ -391,14 +385,14 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { tctx := oteltrace.ContextWithRemoteSpanContext(context.Background(), oteltrace.SpanContextFromContext(r.Context())) ctx, cancel := context.WithCancel(tctx) defer cancel() - t.stream = NewTerminalStream(ctx, ws, t.log) + t.stream = terminal.NewStream(ctx, terminal.StreamConfig{WS: ws, Logger: t.log}) // Create a Teleport client, if not able to, show the reason to the user in // the terminal. tc, err := t.makeClient(ctx, t.stream, ws.RemoteAddr().String()) if err != nil { t.log.WithError(err).Info("Failed creating a client for session") - t.stream.writeError(err.Error()) + t.stream.WriteError(err.Error()) return } @@ -419,7 +413,7 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { }) // Start sending ping frames through websocket to client. - go startPingLoop(ctx, ws, t.keepAliveInterval, t.log, t.Close) + go startWSPingLoop(ctx, ws, t.keepAliveInterval, t.log, t.Close) // Pump raw terminal in/out and audit events into the websocket. go t.streamEvents(ctx, tc) @@ -429,28 +423,17 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { t.log.Debug("Closing websocket stream") } -// SSHSessionLatencyStats contain latency measurements for both -// legs of an ssh connection established via the Web UI. -type SSHSessionLatencyStats struct { - // WebSocket measures the round trip time for a ping/pong via the websocket - // established between the client and the Proxy. - WebSocket int64 `json:"ws"` - // SSH measures the round trip time for a keepalive@openssh.com request via the - // connection established between the Proxy and the target host. - SSH int64 `json:"ssh"` -} - type stderrWriter struct { - stream *TerminalStream + stream *terminal.Stream } func (s stderrWriter) Write(b []byte) (int, error) { - s.stream.writeError(string(b)) + s.stream.WriteError(string(b)) return len(b), nil } // makeClient builds a *client.TeleportClient for the connection. -func (t *TerminalHandler) makeClient(ctx context.Context, stream *TerminalStream, clientAddr string) (*client.TeleportClient, error) { +func (t *TerminalHandler) makeClient(ctx context.Context, stream *terminal.Stream, clientAddr string) (*client.TeleportClient, error) { ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(ctx, "terminal/makeClient") defer span.End() @@ -488,7 +471,7 @@ func (t *TerminalHandler) makeClient(ctx context.Context, stream *TerminalStream // used to update all other parties window size to that of the web client and // to allow future window changes. tc.OnShellCreated = func(s *tracessh.Session, c *tracessh.Client, _ io.ReadWriteCloser) (bool, error) { - t.stream.sessionCreated(s) + t.stream.SessionCreated(s) // The web session was closed by the client while the ssh connection was being established. // Attempt to close the SSH session instead of proceeding with the window change request. @@ -511,7 +494,7 @@ func (t *TerminalHandler) makeClient(ctx context.Context, stream *TerminalStream // used to access nodes which require per-session mfa. The ceremony is performed directly // to make use of the authProvider already established for the session instead of leveraging // the TeleportClient which would require dialing the auth server a second time. -func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, wsStream *WSStream) ([]ssh.AuthMethod, error) { +func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, wsStream *terminal.WSStream) ([]ssh.AuthMethod, error) { ctx, span := t.tracer.Start(ctx, "terminal/issueSessionMFACerts") defer span.End() @@ -653,7 +636,7 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te return []ssh.AuthMethod{am}, nil } -func promptMFAChallenge(stream *WSStream, codec mfaCodec) client.PromptMFAFunc { +func promptMFAChallenge(stream *terminal.WSStream, codec terminal.MFACodec) client.PromptMFAFunc { return func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { var challenge *client.MFAAuthenticateChallenge @@ -667,23 +650,23 @@ func promptMFAChallenge(stream *WSStream, codec mfaCodec) client.PromptMFAFunc { return nil, trace.AccessDenied("only hardware keys are supported on the web terminal, please register a hardware device to connect to this server") } - if err := stream.writeChallenge(challenge, codec); err != nil { + if err := stream.WriteChallenge(challenge, codec); err != nil { return nil, trace.Wrap(err) } - resp, err := stream.readChallengeResponse(codec) + resp, err := stream.ReadChallengeResponse(codec) return resp, trace.Wrap(err) } } -type connectWithMFAFn = func(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) +type connectWithMFAFn = func(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) // connectToHost establishes a connection to the target host. To reduce connection // latency if per session mfa is required, connections are tried with the existing // certs and with single use certs after completing the mfa ceremony. Only one of // the operations will succeed, and if per session mfa will not gain access to the // target it will abort before prompting a user to perform the ceremony. -func (t *sshBaseHandler) connectToHost(ctx context.Context, ws WSConn, tc *client.TeleportClient, connectToNodeWithMFA connectWithMFAFn) (*client.NodeClient, error) { +func (t *sshBaseHandler) connectToHost(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, connectToNodeWithMFA connectWithMFAFn) (*client.NodeClient, error) { ctx, span := t.tracer.Start(ctx, "terminal/connectToHost") defer span.End() @@ -774,8 +757,8 @@ func (t *sshBaseHandler) connectToHost(ctx context.Context, ws WSConn, tc *clien } } -func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *WSStream, sshClient *tracessh.Client) error { - wsPinger, err := latency.NewWebsocketPinger(clock, stream.ws) +func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *terminal.WSStream, sshClient *tracessh.Client) error { + wsPinger, err := latency.NewWebsocketPinger(clock, stream) if err != nil { return trace.Wrap(err, "creating websocket pinger") } @@ -789,7 +772,7 @@ func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *W ClientPinger: wsPinger, ServerPinger: sshPinger, Reporter: latency.ReporterFunc(func(ctx context.Context, statistics latency.Statistics) error { - return trace.Wrap(stream.writeLatency(SSHSessionLatencyStats{ + return trace.Wrap(stream.WriteLatency(terminal.SSHSessionLatencyStats{ WebSocket: statistics.Client, SSH: statistics.Server, })) @@ -804,16 +787,16 @@ func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *W return nil } -// streamTerminal opens a SSH connection to the remote host and streams +// streamTerminal opens an SSH connection to the remote host and streams // events back to the web client. func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.TeleportClient) { ctx, span := t.tracer.Start(ctx, "terminal/streamTerminal") defer span.End() - nc, err := t.connectToHost(ctx, t.stream.ws, tc, t.connectToNodeWithMFA) + nc, err := t.connectToHost(ctx, t.stream, tc, t.connectToNodeWithMFA) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host") - t.stream.writeError(err.Error()) + t.stream.WriteError(err.Error()) return } defer nc.Close() @@ -851,7 +834,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor // either an error occurs or it completes successfully. if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, beforeStart); err != nil { if !t.closedByClient.Load() { - t.stream.writeError(err.Error()) + t.stream.WriteError(err.Error()) } return } @@ -861,7 +844,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor } // Send close envelope to web terminal upon exit without an error. - if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil { + if err := t.stream.SendCloseMessage(t.sessionData.ServerID); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") } @@ -875,7 +858,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor // connectToNode attempts to connect to the host with the already // provisioned certs for the user. -func (t *sshBaseHandler) connectToNode(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { +func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { conn, err := t.router.DialHost(ctx, ws.RemoteAddr(), ws.LocalAddr(), t.sessionData.ServerID, strconv.Itoa(t.sessionData.ServerHostPort), tc.SiteName, accessChecker, getAgent, signer) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host.") @@ -912,7 +895,7 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws WSConn, tc *clien // connectToNodeWithMFA attempts to perform the mfa ceremony and then dial the // host with the retrieved single use certs. -func (t *TerminalHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { +func (t *TerminalHandler) connectToNodeWithMFA(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { // perform mfa ceremony and retrieve new certs authMethods, err := t.issueSessionMFACerts(ctx, tc, t.stream.WSStream) if err != nil { @@ -924,7 +907,7 @@ func (t *TerminalHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, t // connectToNodeWithMFABase attempts to dial the host with the provided auth // methods. -func (t *sshBaseHandler) connectToNodeWithMFABase(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator, authMethods []ssh.AuthMethod) (*client.NodeClient, error) { +func (t *sshBaseHandler) connectToNodeWithMFABase(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator, authMethods []ssh.AuthMethod) (*client.NodeClient, error) { sshConfig := &ssh.ClientConfig{ User: tc.HostLogin, Auth: authMethods, @@ -967,7 +950,7 @@ func (t *TerminalHandler) streamEvents(ctx context.Context, tc *client.TeleportC logger.Debug("Sending audit event to web client.") - if err := t.stream.writeAuditEvent(data); err != nil { + if err := t.stream.WriteAuditEvent(data); err != nil { if errors.Is(err, websocket.ErrCloseSent) { logger.WithError(err).Debug("Websocket was closed, no longer streaming events") return @@ -1026,461 +1009,6 @@ func serverHostPort(servername string) (string, int, error) { return host, port, nil } -func NewWStream(ctx context.Context, ws WSConn, log logrus.FieldLogger, handlers map[string]WSHandlerFunc) *WSStream { - w := &WSStream{ - log: log, - ws: ws, - encoder: unicode.UTF8.NewEncoder(), - decoder: unicode.UTF8.NewDecoder(), - rawC: make(chan Envelope, 100), - challengeC: make(chan Envelope, 1), - handlers: handlers, - } - - go w.processMessages(ctx) - - return w -} - -// NewTerminalStream creates a stream that manages reading and writing -// data over the provided [websocket.Conn] -func NewTerminalStream(ctx context.Context, ws WSConn, log logrus.FieldLogger) *TerminalStream { - t := &TerminalStream{ - sessionReadyC: make(chan struct{}), - } - - handlers := map[string]WSHandlerFunc{ - defaults.WebsocketResize: t.handleWindowResize, - defaults.WebsocketFileTransferRequest: t.handleFileTransferRequest, - defaults.WebsocketFileTransferDecision: t.handleFileTransferDecision, - } - - t.WSStream = NewWStream(ctx, ws, log, handlers) - - return t -} - -// WSHandlerFunc specifies a handler that processes received a specific -// [Envelope] received via a web socket. -type WSHandlerFunc func(context.Context, Envelope) - -// WSStream handles web socket communication with -// the frontend. -type WSStream struct { - // encoder is used to encode UTF-8 strings. - encoder *encoding.Encoder - // decoder is used to decode UTF-8 strings. - decoder *encoding.Decoder - - handlers map[string]WSHandlerFunc - // once ensures that all channels are closed at most one time. - once sync.Once - challengeC chan Envelope - rawC chan Envelope - - // buffer is a buffer used to store the remaining payload data if it did not - // fit into the buffer provided by the callee to Read method - buffer []byte - - // mu protects writes to ws - mu sync.Mutex - // ws the connection to the UI - ws WSConn - - // log holds the structured logger. - log logrus.FieldLogger -} - -// TerminalStream manages the [websocket.Conn] to the web UI -// for a terminal session. -type TerminalStream struct { - *WSStream - - // sshSession holds the "shell" SSH channel to the node. - sshSession *tracessh.Session - sessionReadyC chan struct{} -} - -// Replace \n with \r\n so the message is correctly aligned. -var replacer = strings.NewReplacer("\r\n", "\r\n", "\n", "\r\n") - -// writeError displays an error in the terminal window. -func (t *WSStream) writeError(msg string) { - if _, writeErr := replacer.WriteString(t, msg); writeErr != nil { - t.log.WithError(writeErr).Warnf("Unable to send error to terminal: %v", msg) - } -} - -func (t *WSStream) processMessages(ctx context.Context) { - defer func() { - t.close() - }() - t.ws.SetReadLimit(teleport.MaxHTTPRequestSize) - - for { - select { - case <-ctx.Done(): - return - default: - ty, bytes, err := t.ws.ReadMessage() - if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || - websocket.IsCloseError(err, websocket.CloseAbnormalClosure, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - return - } - - msg := err.Error() - if len(bytes) > 0 { - msg = string(bytes) - } - select { - case <-ctx.Done(): - default: - t.writeError(msg) - return - } - } - - if ty != websocket.BinaryMessage { - t.writeError(fmt.Sprintf("Expected binary message, got %v", ty)) - return - } - - var envelope Envelope - if err := proto.Unmarshal(bytes, &envelope); err != nil { - t.writeError(fmt.Sprintf("Unable to parse message payload %v", err)) - return - } - - switch envelope.Type { - case defaults.WebsocketClose: - return - case defaults.WebsocketWebauthnChallenge: - select { - case <-ctx.Done(): - return - case t.challengeC <- envelope: - default: - } - case defaults.WebsocketRaw: - select { - case <-ctx.Done(): - return - case t.rawC <- envelope: - default: - } - default: - if t.handlers == nil { - continue - } - - handler, ok := t.handlers[envelope.Type] - if !ok { - t.log.Warnf("Received web socket envelope with unknown type %v", envelope.Type) - continue - } - - go handler(ctx, envelope) - } - } - } -} - -// handleWindowResize receives window resize events and forwards -// them to the SSH session. -func (t *TerminalStream) handleWindowResize(ctx context.Context, envelope Envelope) { - select { - case <-ctx.Done(): - return - case <-t.sessionReadyC: - } - - if t.sshSession == nil { - return - } - - var e map[string]interface{} - err := json.Unmarshal([]byte(envelope.Payload), &e) - if err != nil { - t.log.Warnf("Failed to parse resize payload: %v", err) - return - } - - size, ok := e["size"].(string) - if !ok { - t.log.Errorf("expected size to be of type string, got type %T instead", size) - return - } - - params, err := session.UnmarshalTerminalParams(size) - if err != nil { - t.log.Warnf("Failed to retrieve terminal size: %v", err) - return - } - - // nil params indicates the channel was closed - if params == nil { - return - } - - if err := t.sshSession.WindowChange(ctx, params.H, params.W); err != nil { - t.log.Error(err) - } -} - -func (t *TerminalStream) handleFileTransferDecision(ctx context.Context, envelope Envelope) { - select { - case <-ctx.Done(): - return - case <-t.sessionReadyC: - } - - if t.sshSession == nil { - return - } - - var e utils.Fields - err := json.Unmarshal([]byte(envelope.Payload), &e) - if err != nil { - return - } - approved, ok := e["approved"].(bool) - if !ok { - t.log.Error("Unable to find approved status on response") - return - } - - if approved { - err = t.sshSession.ApproveFileTransferRequest(ctx, e.GetString("requestId")) - } else { - err = t.sshSession.DenyFileTransferRequest(ctx, e.GetString("requestId")) - } - if err != nil { - t.log.WithError(err).Error("Unable to respond to file transfer request") - } -} - -func (t *TerminalStream) handleFileTransferRequest(ctx context.Context, envelope Envelope) { - select { - case <-ctx.Done(): - return - case <-t.sessionReadyC: - } - - if t.sshSession == nil { - return - } - - var e utils.Fields - err := json.Unmarshal([]byte(envelope.Payload), &e) - if err != nil { - return - } - download, ok := e["download"].(bool) - if !ok { - t.log.Error("Unable to find download param in response") - return - } - - if err := t.sshSession.RequestFileTransfer(ctx, tracessh.FileTransferReq{ - Download: download, - Location: e.GetString("location"), - Filename: e.GetString("filename"), - }); err != nil { - t.log.WithError(err).Error("Unable to request file transfer") - } -} - -func (t *TerminalStream) sessionCreated(s *tracessh.Session) { - t.sshSession = s - close(t.sessionReadyC) -} - -// writeChallenge encodes and writes the challenge to the -// websocket in the correct format. -func (t *WSStream) writeChallenge(challenge *client.MFAAuthenticateChallenge, codec mfaCodec) error { - // Send the challenge over the socket. - msg, err := codec.encode(challenge, defaults.WebsocketWebauthnChallenge) - if err != nil { - return trace.Wrap(err) - } - - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, msg)) -} - -// readChallengeResponse reads and decodes the challenge response from the -// websocket in the correct format. -func (t *WSStream) readChallengeResponse(codec mfaCodec) (*authproto.MFAAuthenticateResponse, error) { - envelope, ok := <-t.challengeC - if !ok { - return nil, io.EOF - } - resp, err := codec.decodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) - return resp, trace.Wrap(err) -} - -// readChallenge reads and decodes the challenge from the -// websocket in the correct format. -func (t *WSStream) readChallenge(codec mfaCodec) (*authproto.MFAAuthenticateChallenge, error) { - envelope, ok := <-t.challengeC - if !ok { - return nil, io.EOF - } - challenge, err := codec.decodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) - return challenge, trace.Wrap(err) -} - -// writeAuditEvent encodes and writes the audit event to the -// websocket in the correct format. -func (t *WSStream) writeAuditEvent(event []byte) error { - // UTF-8 encode the error message and then wrap it in a raw envelope. - encodedPayload, err := t.encoder.String(string(event)) - if err != nil { - return trace.Wrap(err) - } - - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketAudit, - Payload: encodedPayload, - } - - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - return trace.Wrap(err) - } - - // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) -} - -func (t *WSStream) writeLatency(latency SSHSessionLatencyStats) error { - data, err := json.Marshal(latency) - if err != nil { - return trace.Wrap(err) - } - - encodedPayload, err := t.encoder.String(string(data)) - if err != nil { - return trace.Wrap(err) - } - - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketLatency, - Payload: encodedPayload, - } - - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - return trace.Wrap(err) - } - - // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) -} - -// Write wraps the data bytes in a raw envelope and sends. -func (t *WSStream) Write(data []byte) (n int, err error) { - // UTF-8 encode data and wrap it in a raw envelope. - encodedPayload, err := t.encoder.String(string(data)) - if err != nil { - return 0, trace.Wrap(err) - } - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketRaw, - Payload: encodedPayload, - } - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - return 0, trace.Wrap(err) - } - - // Send bytes over the websocket to the web client. - t.mu.Lock() - err = t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) - t.mu.Unlock() - if err != nil { - return 0, trace.Wrap(err) - } - - return len(data), nil -} - -// Read provides data received from [defaults.WebsocketRaw] envelopes. If -// the previous envelope was not consumed in the last read, any remaining data -// is returned prior to processing the next envelope. -func (t *WSStream) Read(out []byte) (int, error) { - if len(t.buffer) > 0 { - n := copy(out, t.buffer) - if n == len(t.buffer) { - t.buffer = []byte{} - } else { - t.buffer = t.buffer[n:] - } - return n, nil - } - - envelope, ok := <-t.rawC - if !ok { - return 0, io.EOF - } - - data, err := t.decoder.Bytes([]byte(envelope.Payload)) - if err != nil { - return 0, trace.Wrap(err) - } - - n := copy(out, data) - // if the payload size is greater than [out], store the remaining - // part in the buffer to be processed on the next Read call - if len(data) > n { - t.buffer = data[n:] - } - return n, nil -} - -// SendCloseMessage sends a close message on the web socket. -func (t *WSStream) SendCloseMessage(event sessionEndEvent) error { - sessionMetadataPayload, err := json.Marshal(&event) - if err != nil { - return trace.Wrap(err) - } - - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketClose, - Payload: string(sessionMetadataPayload), - } - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - return trace.Wrap(err) - } - - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) -} - -func (t *WSStream) close() { - t.once.Do(func() { - close(t.rawC) - close(t.challengeC) - }) -} - -// Close sends a close message on the web socket and closes the web socket. -func (t *WSStream) Close() error { - return trace.Wrap(t.ws.Close()) -} - // deadlineForInterval returns a suitable network read deadline for a given ping interval. // We chose to take the current time plus twice the interval to allow the timeframe of one interval // to wait for a returned pong message. diff --git a/lib/web/envelope.pb.go b/lib/web/terminal/envelope.pb.go similarity index 88% rename from lib/web/envelope.pb.go rename to lib/web/terminal/envelope.pb.go index 068cb70552637..10fcae2a7f60a 100644 --- a/lib/web/envelope.pb.go +++ b/lib/web/terminal/envelope.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gogo. DO NOT EDIT. -// source: teleport/lib/web/envelope.proto +// source: teleport/lib/web/terminal/envelope.proto -package web +package terminal import ( fmt "fmt" @@ -41,7 +41,7 @@ func (m *Envelope) Reset() { *m = Envelope{} } func (m *Envelope) String() string { return proto.CompactTextString(m) } func (*Envelope) ProtoMessage() {} func (*Envelope) Descriptor() ([]byte, []int) { - return fileDescriptor_ee3212c7e303fe4c, []int{0} + return fileDescriptor_016ae8368e6afaa7, []int{0} } func (m *Envelope) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -95,21 +95,24 @@ func init() { proto.RegisterType((*Envelope)(nil), "teleport.lib.web.Envelope") } -func init() { proto.RegisterFile("teleport/lib/web/envelope.proto", fileDescriptor_ee3212c7e303fe4c) } +func init() { + proto.RegisterFile("teleport/lib/web/terminal/envelope.proto", fileDescriptor_016ae8368e6afaa7) +} -var fileDescriptor_ee3212c7e303fe4c = []byte{ - // 174 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2f, 0x49, 0xcd, 0x49, - 0x2d, 0xc8, 0x2f, 0x2a, 0xd1, 0xcf, 0xc9, 0x4c, 0xd2, 0x2f, 0x4f, 0x4d, 0xd2, 0x4f, 0xcd, 0x2b, - 0x4b, 0xcd, 0xc9, 0x2f, 0x48, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x80, 0x29, 0xd0, - 0xcb, 0xc9, 0x4c, 0xd2, 0x2b, 0x4f, 0x4d, 0x52, 0x0a, 0xe2, 0xe2, 0x70, 0x85, 0xaa, 0x11, 0x92, - 0xe0, 0x62, 0x0f, 0x4b, 0x2d, 0x2a, 0xce, 0xcc, 0xcf, 0x93, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, - 0x82, 0x71, 0x85, 0x84, 0xb8, 0x58, 0x42, 0x2a, 0x0b, 0x52, 0x25, 0x98, 0xc0, 0xc2, 0x60, 0x36, - 0x48, 0x75, 0x40, 0x62, 0x65, 0x4e, 0x7e, 0x62, 0x8a, 0x04, 0x33, 0x44, 0x35, 0x94, 0xeb, 0x64, - 0x7e, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x46, 0x69, 0xa6, - 0x67, 0x96, 0x64, 0x94, 0x26, 0xe9, 0x25, 0xe7, 0xe7, 0xea, 0xa7, 0x17, 0x25, 0x96, 0x65, 0x96, - 0x24, 0x96, 0x64, 0xe6, 0xe7, 0x25, 0xe6, 0xe8, 0xa3, 0xbb, 0x36, 0x89, 0x0d, 0xec, 0x4a, 0x63, - 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa3, 0x18, 0x14, 0xb9, 0xc8, 0x00, 0x00, 0x00, +var fileDescriptor_016ae8368e6afaa7 = []byte{ + // 181 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x28, 0x49, 0xcd, 0x49, + 0x2d, 0xc8, 0x2f, 0x2a, 0xd1, 0xcf, 0xc9, 0x4c, 0xd2, 0x2f, 0x4f, 0x4d, 0xd2, 0x2f, 0x49, 0x2d, + 0xca, 0xcd, 0xcc, 0x4b, 0xcc, 0xd1, 0x4f, 0xcd, 0x2b, 0x4b, 0xcd, 0xc9, 0x2f, 0x48, 0xd5, 0x2b, + 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x80, 0xa9, 0xd4, 0xcb, 0xc9, 0x4c, 0xd2, 0x2b, 0x4f, 0x4d, + 0x52, 0x0a, 0xe2, 0xe2, 0x70, 0x85, 0xaa, 0x11, 0x92, 0xe0, 0x62, 0x0f, 0x4b, 0x2d, 0x2a, 0xce, + 0xcc, 0xcf, 0x93, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x82, 0x71, 0x85, 0x84, 0xb8, 0x58, 0x42, + 0x2a, 0x0b, 0x52, 0x25, 0x98, 0xc0, 0xc2, 0x60, 0x36, 0x48, 0x75, 0x40, 0x62, 0x65, 0x4e, 0x7e, + 0x62, 0x8a, 0x04, 0x33, 0x44, 0x35, 0x94, 0xeb, 0xe4, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, + 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x46, 0x19, 0xa5, 0x67, 0x96, 0x64, 0x94, 0x26, 0xe9, 0x25, + 0xe7, 0xe7, 0xea, 0xa7, 0x17, 0x25, 0x96, 0x65, 0x96, 0x24, 0x96, 0x64, 0xe6, 0x83, 0x1c, 0x87, + 0xd3, 0xd9, 0x49, 0x6c, 0x60, 0xe7, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x44, 0xc0, 0x7f, + 0x05, 0xda, 0x00, 0x00, 0x00, } func (m *Envelope) Marshal() (dAtA []byte, err error) { diff --git a/lib/web/terminal/terminal.go b/lib/web/terminal/terminal.go new file mode 100644 index 0000000000000..2ad18331a673b --- /dev/null +++ b/lib/web/terminal/terminal.go @@ -0,0 +1,596 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package terminal + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/gogo/protobuf/proto" + "github.com/gorilla/websocket" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "golang.org/x/text/encoding" + "golang.org/x/text/encoding/unicode" + + "github.com/gravitational/teleport" + authproto "github.com/gravitational/teleport/api/client/proto" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/utils" +) + +// WSConn is a gorilla/websocket minimal interface used by our web implementation. +// This interface exists to override the default websocket.Conn implementation, +// currently used by noopCloserWS to prevent WS being closed by wrapping stream. +type WSConn interface { + Close() error + + LocalAddr() net.Addr + RemoteAddr() net.Addr + + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) + SetReadLimit(limit int64) + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + + PongHandler() func(appData string) error + SetPongHandler(h func(appData string) error) + CloseHandler() func(code int, text string) error + SetCloseHandler(h func(code int, text string) error) +} + +// WSHandlerFunc specifies a handler that processes received a specific +// [Envelope] received via a web socket. +type WSHandlerFunc func(context.Context, Envelope) + +// WSStream handles web socket communication with +// the frontend. +type WSStream struct { + // encoder is used to encode UTF-8 strings. + encoder *encoding.Encoder + // decoder is used to decode UTF-8 strings. + decoder *encoding.Decoder + + handlers map[string]WSHandlerFunc + // once ensures that all channels are closed at most one time. + once sync.Once + challengeC chan Envelope + rawC chan Envelope + + // buffer is a buffer used to store the remaining payload data if it did not + // fit into the buffer provided by the callee to Read method + buffer []byte + + // mu protects writes to ws + mu sync.Mutex + // ws the connection to the UI + WSConn + + // log holds the structured logger. + log logrus.FieldLogger +} + +// Replace \n with \r\n so the message is correctly aligned. +var replacer = strings.NewReplacer("\r\n", "\r\n", "\n", "\r\n") + +// WriteError displays an error in the terminal window. +func (t *WSStream) WriteError(msg string) { + if _, writeErr := replacer.WriteString(t, msg); writeErr != nil { + t.log.WithError(writeErr).Warnf("Unable to send error to terminal: %v", msg) + } +} + +func (t *WSStream) SetReadDeadline(deadline time.Time) error { + return t.WSConn.SetReadDeadline(deadline) +} + +func isOKWebsocketCloseError(err error) bool { + return websocket.IsCloseError(err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + ) +} + +func (t *WSStream) processMessages(ctx context.Context) { + defer func() { + t.close() + }() + t.WSConn.SetReadLimit(teleport.MaxHTTPRequestSize) + + for { + select { + case <-ctx.Done(): + return + default: + ty, bytes, err := t.WSConn.ReadMessage() + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || isOKWebsocketCloseError(err) { + return + } + + msg := err.Error() + if len(bytes) > 0 { + msg = string(bytes) + } + select { + case <-ctx.Done(): + default: + t.WriteError(msg) + return + } + } + + if ty != websocket.BinaryMessage { + t.WriteError(fmt.Sprintf("Expected binary message, got %v", ty)) + return + } + + var envelope Envelope + if err := proto.Unmarshal(bytes, &envelope); err != nil { + t.WriteError(fmt.Sprintf("Unable to parse message payload %v", err)) + return + } + + switch envelope.Type { + case defaults.WebsocketClose: + return + case defaults.WebsocketWebauthnChallenge: + select { + case <-ctx.Done(): + return + case t.challengeC <- envelope: + default: + } + case defaults.WebsocketRaw: + select { + case <-ctx.Done(): + return + case t.rawC <- envelope: + default: + } + default: + if t.handlers == nil { + continue + } + + handler, ok := t.handlers[envelope.Type] + if !ok { + t.log.Warnf("Received web socket envelope with unknown type %v", envelope.Type) + continue + } + + go handler(ctx, envelope) + } + } + } +} + +// MFACodec converts MFA challenges/responses between their native types and a format +// suitable for being sent over a network connection. +type MFACodec interface { + // Encode converts an MFA challenge to wire format + Encode(chal *client.MFAAuthenticateChallenge, envelopeType string) ([]byte, error) + + // DecodeChallenge parses an MFA authentication challenge + DecodeChallenge(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateChallenge, error) + + // DecodeResponse parses an MFA authentication response + DecodeResponse(bytes []byte, envelopeType string) (*authproto.MFAAuthenticateResponse, error) +} + +// WriteChallenge encodes and writes the challenge to the +// websocket in the correct format. +func (t *WSStream) WriteChallenge(challenge *client.MFAAuthenticateChallenge, codec MFACodec) error { + // Send the challenge over the socket. + msg, err := codec.Encode(challenge, defaults.WebsocketWebauthnChallenge) + if err != nil { + return trace.Wrap(err) + } + + t.mu.Lock() + defer t.mu.Unlock() + return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, msg)) +} + +// ReadChallengeResponse reads and decodes the challenge response from the +// websocket in the correct format. +func (t *WSStream) ReadChallengeResponse(codec MFACodec) (*authproto.MFAAuthenticateResponse, error) { + envelope, ok := <-t.challengeC + if !ok { + return nil, io.EOF + } + resp, err := codec.DecodeResponse([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) + return resp, trace.Wrap(err) +} + +// ReadChallenge reads and decodes the challenge from the +// websocket in the correct format. +func (t *WSStream) ReadChallenge(codec MFACodec) (*authproto.MFAAuthenticateChallenge, error) { + envelope, ok := <-t.challengeC + if !ok { + return nil, io.EOF + } + challenge, err := codec.DecodeChallenge([]byte(envelope.Payload), defaults.WebsocketWebauthnChallenge) + return challenge, trace.Wrap(err) +} + +// WriteAuditEvent encodes and writes the audit event to the +// websocket in the correct format. +func (t *WSStream) WriteAuditEvent(event []byte) error { + // UTF-8 encode the error message and then wrap it in a raw envelope. + encodedPayload, err := t.encoder.String(string(event)) + if err != nil { + return trace.Wrap(err) + } + + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketAudit, + Payload: encodedPayload, + } + + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + // Send bytes over the websocket to the web client. + t.mu.Lock() + defer t.mu.Unlock() + return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +// SSHSessionLatencyStats contain latency measurements for both +// legs of an ssh connection established via the Web UI. +type SSHSessionLatencyStats struct { + // WebSocket measures the round trip time for a ping/pong via the websocket + // established between the client and the Proxy. + WebSocket int64 `json:"ws"` + // SSH measures the round trip time for a keepalive@openssh.com request via the + // connection established between the Proxy and the target host. + SSH int64 `json:"ssh"` +} + +// WriteLatency encodes and writes latency statistics. +func (t *WSStream) WriteLatency(latency SSHSessionLatencyStats) error { + data, err := json.Marshal(latency) + if err != nil { + return trace.Wrap(err) + } + + encodedPayload, err := t.encoder.String(string(data)) + if err != nil { + return trace.Wrap(err) + } + + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketLatency, + Payload: encodedPayload, + } + + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + // Send bytes over the websocket to the web client. + t.mu.Lock() + defer t.mu.Unlock() + return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +// Write wraps the data bytes in a raw envelope and sends. +func (t *WSStream) Write(data []byte) (n int, err error) { + // UTF-8 encode data and wrap it in a raw envelope. + encodedPayload, err := t.encoder.String(string(data)) + if err != nil { + return 0, trace.Wrap(err) + } + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketRaw, + Payload: encodedPayload, + } + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + return 0, trace.Wrap(err) + } + + // Send bytes over the websocket to the web client. + t.mu.Lock() + err = t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes) + t.mu.Unlock() + if err != nil { + return 0, trace.Wrap(err) + } + + return len(data), nil +} + +// Read provides data received from [defaults.WebsocketRaw] envelopes. If +// the previous envelope was not consumed in the last read, any remaining data +// is returned prior to processing the next envelope. +func (t *WSStream) Read(out []byte) (int, error) { + if len(t.buffer) > 0 { + n := copy(out, t.buffer) + if n == len(t.buffer) { + t.buffer = []byte{} + } else { + t.buffer = t.buffer[n:] + } + return n, nil + } + + envelope, ok := <-t.rawC + if !ok { + return 0, io.EOF + } + + data, err := t.decoder.Bytes([]byte(envelope.Payload)) + if err != nil { + return 0, trace.Wrap(err) + } + + n := copy(out, data) + // if the payload size is greater than [out], store the remaining + // part in the buffer to be processed on the next Read call + if len(data) > n { + t.buffer = data[n:] + } + return n, nil +} + +// sessionEndEvent is an event sent when a session ends. +type sessionEndEvent struct { + // NodeID is the ID of the server where the session was created. + NodeID string `json:"node_id"` +} + +// SendCloseMessage sends a close message on the web socket. +func (t *WSStream) SendCloseMessage(id string) error { + sessionMetadataPayload, err := json.Marshal(&sessionEndEvent{NodeID: id}) + if err != nil { + return trace.Wrap(err) + } + + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketClose, + Payload: string(sessionMetadataPayload), + } + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + return trace.Wrap(err) + } + + t.mu.Lock() + defer t.mu.Unlock() + return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) +} + +func (t *WSStream) close() { + t.once.Do(func() { + close(t.rawC) + close(t.challengeC) + }) +} + +// Close sends a close message on the web socket and closes the web socket. +func (t *WSStream) Close() error { + return trace.Wrap(t.WSConn.Close()) +} + +// Stream manages the [websocket.Conn] to the web UI +// for a terminal session. +type Stream struct { + *WSStream + + // sshSession holds the "shell" SSH channel to the node. + sshSession *tracessh.Session + sessionReadyC chan struct{} +} + +// StreamConfig contains dependencies of a TerminalStream. +type StreamConfig struct { + // The websocket to operate over. Required. + WS WSConn + // A logger to emit log messages. Optional. + Logger logrus.FieldLogger + // A custom set of handlers to process messages received + // over the websocket. Optional. + Handlers map[string]WSHandlerFunc +} + +func NewWStream(ctx context.Context, ws WSConn, log logrus.FieldLogger, handlers map[string]WSHandlerFunc) *WSStream { + w := &WSStream{ + log: log, + WSConn: ws, + encoder: unicode.UTF8.NewEncoder(), + decoder: unicode.UTF8.NewDecoder(), + rawC: make(chan Envelope, 100), + challengeC: make(chan Envelope, 1), + handlers: handlers, + } + + go w.processMessages(ctx) + + return w +} + +// NewStream creates a stream that manages reading and writing +// data over the provided [websocket.Conn] +func NewStream(ctx context.Context, cfg StreamConfig) *Stream { + t := &Stream{ + sessionReadyC: make(chan struct{}), + } + + if cfg.Handlers == nil { + cfg.Handlers = map[string]WSHandlerFunc{} + } + + if _, ok := cfg.Handlers[defaults.WebsocketResize]; !ok { + cfg.Handlers[defaults.WebsocketResize] = t.handleWindowResize + } + + if _, ok := cfg.Handlers[defaults.WebsocketFileTransferRequest]; !ok { + cfg.Handlers[defaults.WebsocketFileTransferRequest] = t.handleFileTransferRequest + } + + if _, ok := cfg.Handlers[defaults.WebsocketFileTransferDecision]; !ok { + cfg.Handlers[defaults.WebsocketFileTransferDecision] = t.handleFileTransferDecision + } + + if cfg.Logger == nil { + cfg.Logger = utils.NewLogger() + } + + t.WSStream = NewWStream(ctx, cfg.WS, cfg.Logger, cfg.Handlers) + + return t +} + +// handleWindowResize receives window resize events and forwards +// them to the SSH session. +func (t *Stream) handleWindowResize(ctx context.Context, envelope Envelope) { + select { + case <-ctx.Done(): + return + case <-t.sessionReadyC: + } + + if t.sshSession == nil { + return + } + + var e map[string]interface{} + err := json.Unmarshal([]byte(envelope.Payload), &e) + if err != nil { + t.log.Warnf("Failed to parse resize payload: %v", err) + return + } + + size, ok := e["size"].(string) + if !ok { + t.log.Errorf("expected size to be of type string, got type %T instead", size) + return + } + + params, err := session.UnmarshalTerminalParams(size) + if err != nil { + t.log.Warnf("Failed to retrieve terminal size: %v", err) + return + } + + // nil params indicates the channel was closed + if params == nil { + return + } + + if err := t.sshSession.WindowChange(ctx, params.H, params.W); err != nil { + t.log.Error(err) + } +} + +func (t *Stream) handleFileTransferDecision(ctx context.Context, envelope Envelope) { + select { + case <-ctx.Done(): + return + case <-t.sessionReadyC: + } + + if t.sshSession == nil { + return + } + + var e utils.Fields + err := json.Unmarshal([]byte(envelope.Payload), &e) + if err != nil { + return + } + approved, ok := e["approved"].(bool) + if !ok { + t.log.Error("Unable to find approved status on response") + return + } + + if approved { + err = t.sshSession.ApproveFileTransferRequest(ctx, e.GetString("requestId")) + } else { + err = t.sshSession.DenyFileTransferRequest(ctx, e.GetString("requestId")) + } + if err != nil { + t.log.WithError(err).Error("Unable to respond to file transfer request") + } +} + +func (t *Stream) handleFileTransferRequest(ctx context.Context, envelope Envelope) { + select { + case <-ctx.Done(): + return + case <-t.sessionReadyC: + } + + if t.sshSession == nil { + return + } + + var e utils.Fields + err := json.Unmarshal([]byte(envelope.Payload), &e) + if err != nil { + return + } + download, ok := e["download"].(bool) + if !ok { + t.log.Error("Unable to find download param in response") + return + } + + if err := t.sshSession.RequestFileTransfer(ctx, tracessh.FileTransferReq{ + Download: download, + Location: e.GetString("location"), + Filename: e.GetString("filename"), + }); err != nil { + t.log.WithError(err).Error("Unable to request file transfer") + } +} + +func (t *Stream) SessionCreated(s *tracessh.Session) { + t.sshSession = s + close(t.sessionReadyC) +} + +func (t *Stream) Close() error { + if t.sshSession != nil { + return trace.NewAggregate(t.sshSession.Close(), t.WSStream.Close()) + } else { + return trace.Wrap(t.WSStream.Close()) + } +} diff --git a/lib/web/terminal_test.go b/lib/web/terminal_test.go index 884613d0a04f8..5eaaf6b3dbd7e 100644 --- a/lib/web/terminal_test.go +++ b/lib/web/terminal_test.go @@ -29,8 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/web" + "github.com/gravitational/teleport/lib/web/terminal" ) // TestTerminalReadFromClosedConn verifies that Teleport recovers @@ -47,7 +46,7 @@ func TestTerminalReadFromClosedConn(t *testing.T) { t.Errorf("couldn't upgrade websocket connection: %v", err) } - envelope := web.Envelope{ + envelope := terminal.Envelope{ Type: defaults.WebsocketRaw, Payload: "hello", } @@ -64,7 +63,7 @@ func TestTerminalReadFromClosedConn(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() - stream := web.NewTerminalStream(context.Background(), conn, utils.NewLoggerForTests()) + stream := terminal.NewStream(context.Background(), terminal.StreamConfig{WS: conn}) // close the stream before we attempt to read from it, // this will produce a net.ErrClosed error on the read diff --git a/lib/web/ws_io.go b/lib/web/ws_io.go index a1d0f4888a06f..68e7b3d91456c 100644 --- a/lib/web/ws_io.go +++ b/lib/web/ws_io.go @@ -61,11 +61,15 @@ func (ws *WebsocketIO) Close() error { return trace.Wrap(ws.Conn.Close()) } -// startPingLoop starts a loop that will continuously send a ping frame through the websocket +type wsPinger interface { + WriteControl(messageType int, data []byte, deadline time.Time) error +} + +// startWSPingLoop starts a loop that will continuously send a ping frame through the websocket // to prevent the connection between web client and teleport proxy from becoming idle. // Interval is determined by the keep_alive_interval config set by user (or default). // Loop will terminate when there is an error sending ping frame or when the context is canceled. -func startPingLoop(ctx context.Context, ws WSConn, keepAliveInterval time.Duration, log logrus.FieldLogger, onClose func() error) { +func startWSPingLoop(ctx context.Context, pinger wsPinger, keepAliveInterval time.Duration, log logrus.FieldLogger, onClose func() error) { log.Debugf("Starting websocket ping loop with interval %v.", keepAliveInterval) tickerCh := time.NewTicker(keepAliveInterval) defer tickerCh.Stop() @@ -76,7 +80,7 @@ func startPingLoop(ctx context.Context, ws WSConn, keepAliveInterval time.Durati // A short deadline is used here to detect a broken connection quickly. // If this is just a temporary issue, we will retry shortly anyway. deadline := time.Now().Add(time.Second) - if err := ws.WriteControl(websocket.PingMessage, nil, deadline); err != nil { + if err := pinger.WriteControl(websocket.PingMessage, nil, deadline); err != nil { log.WithError(err).Error("Unable to send ping frame to web client") if onClose != nil { if err := onClose(); err != nil { diff --git a/proto/buf.yaml b/proto/buf.yaml index 712df792ddb18..2a48f29c081c7 100644 --- a/proto/buf.yaml +++ b/proto/buf.yaml @@ -20,7 +20,7 @@ lint: ignore: # "legacy" lib protos. - teleport/lib/multiplexer/test/ping.proto - - teleport/lib/web/envelope.proto + - teleport/lib/web/terminal/envelope.proto ignore_only: # Allow only certain services to use streaming RPCs. # diff --git a/proto/teleport/lib/web/envelope.proto b/proto/teleport/lib/web/terminal/envelope.proto similarity index 93% rename from proto/teleport/lib/web/envelope.proto rename to proto/teleport/lib/web/terminal/envelope.proto index 0391e85acfc7d..2ac5f5e4e9d57 100644 --- a/proto/teleport/lib/web/envelope.proto +++ b/proto/teleport/lib/web/terminal/envelope.proto @@ -16,7 +16,7 @@ syntax = "proto3"; package teleport.lib.web; -option go_package = "github.com/gravitational/teleport/lib/web"; +option go_package = "github.com/gravitational/teleport/lib/web/terminal"; // Envelope is used to wrap and transend and receive messages between the // web client and proxy. diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 19ec11f69189f..435402a708806 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -65,7 +65,6 @@ import ( "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/prompt" "github.com/gravitational/teleport/lib/asciitable" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" "github.com/gravitational/teleport/lib/benchmark" @@ -1982,7 +1981,7 @@ func onLogin(cf *CLIConf) error { key, err := tc.Login(cf.Context) if err != nil { - if !cf.ExplicitUsername && auth.IsInvalidLocalCredentialError(err) { + if !cf.ExplicitUsername && authclient.IsInvalidLocalCredentialError(err) { fmt.Fprintf(os.Stderr, "\nhint: set the --user flag to log in as a specific user, or leave it empty to use the system user (%v)\n\n", tc.Username) } return trace.Wrap(err)