From 4df90e115c2caeec73c736794674a1cb7d6bf7dd Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sun, 29 Oct 2023 00:17:01 +0800 Subject: [PATCH] support connection sharing #29 #43 ControlMaster ControlPath ControlPersist --- tssh/args.go | 1 + tssh/control.go | 325 +++++++++++++++++++++++++++++++++++++++++++ tssh/ctrl_unix.go | 276 ++++++++++++++++++++++++++++++++++++ tssh/ctrl_windows.go | 50 +++++++ tssh/forward.go | 27 ++-- tssh/login.go | 77 +++++----- tssh/main.go | 1 + tssh/prompt.go | 2 +- tssh/tokens.go | 95 +++++++++++++ tssh/tokens_test.go | 93 +++++++++++++ 10 files changed, 905 insertions(+), 42 deletions(-) create mode 100644 tssh/control.go create mode 100644 tssh/ctrl_unix.go create mode 100644 tssh/ctrl_windows.go create mode 100644 tssh/tokens.go create mode 100644 tssh/tokens_test.go diff --git a/tssh/args.go b/tssh/args.go index 3bf57f1..24949f4 100644 --- a/tssh/args.go +++ b/tssh/args.go @@ -74,6 +74,7 @@ type sshArgs struct { Relay bool `arg:"--relay" help:"force trzsz run as a relay on the jump server"` Debug bool `arg:"--debug" help:"verbose mode for debugging, same as ssh's -vvv"` EncSecret bool `arg:"--enc-secret" help:"encode secret for configuration(~/.ssh/password)"` + originalDest string } func (sshArgs) Description() string { diff --git a/tssh/control.go b/tssh/control.go new file mode 100644 index 0000000..6726843 --- /dev/null +++ b/tssh/control.go @@ -0,0 +1,325 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tssh + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "net" + "sync" + _ "unsafe" + + "golang.org/x/crypto/ssh" +) + +//go:linkname newMux golang.org/x/crypto/ssh.newMux +func newMux(p packetConn) *mux + +//go:linkname muxSendRequest golang.org/x/crypto/ssh.(*mux).SendRequest +func muxSendRequest(m *mux, name string, wantReply bool, payload []byte) (bool, []byte, error) + +//go:linkname muxOpenChannel golang.org/x/crypto/ssh.(*mux).OpenChannel +func muxOpenChannel(m *mux, chanType string, extra []byte) (ssh.Channel, <-chan *ssh.Request, error) + +//go:linkname muxWait golang.org/x/crypto/ssh.(*mux).Wait +func muxWait(m *mux) error + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection. The read is blocking, + // i.e. if error is nil, then the returned byte slice is + // always non-empty. + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct{} // nolint:all + +// chanList is a thread safe channel list. +type chanList struct { // nolint:all + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn // nolint:all + chanList chanList // nolint:all + + incomingChannels chan ssh.NewChannel + + globalSentMu sync.Mutex // nolint:all + globalResponses chan interface{} // nolint:all + incomingRequests chan *ssh.Request + + errCond *sync.Cond // nolint:all + err error // nolint:all +} + +type connTransport interface { + packetConn + getSessionID() []byte + waitSession() error +} + +// A connection represents an incoming connection. +type connection struct { + transport connTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +func (c *connection) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + return muxSendRequest(c.mux, name, wantReply, payload) +} + +func (c *connection) OpenChannel(chanType string, extra []byte) (ssh.Channel, <-chan *ssh.Request, error) { + return muxOpenChannel(c.mux, chanType, extra) +} + +func (c *connection) Wait() error { + return muxWait(c.mux) +} + +// sshConn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} + +// NewControlClientConn establishes an SSH connection over an OpenSSH +// ControlMaster socket c in proxy mode. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewControlClientConn(c net.Conn) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + conn := &connection{ + sshConn: sshConn{conn: c}, + } + var err error + if conn.transport, err = handshakeControlProxy(c); err != nil { + return nil, nil, nil, fmt.Errorf("ssh: control proxy handshake failed; %v", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +const ( + muxMsgHello = 0x00000001 + muxCliProxy = 0x1000000f + muxSvrProxy = 0x8000000f + muxSFailure = 0x80000003 +) + +// handshakeControlProxy attempts to establish a transport connection with an +// OpenSSH ControlMaster socket in proxy mode. For details see: +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.mux +func handshakeControlProxy(rw io.ReadWriteCloser) (connTransport, error) { + b := &controlBuffer{} + b.writeUint32(muxMsgHello) + b.writeUint32(4) // Protocol Version + if _, err := rw.Write(b.lengthPrefixedBytes()); err != nil { + return nil, fmt.Errorf("mux hello write failed: %v", err) + } + + b.Reset() + b.writeUint32(muxCliProxy) + b.writeUint32(0) // Request ID + if _, err := rw.Write(b.lengthPrefixedBytes()); err != nil { + return nil, fmt.Errorf("mux client proxy write failed: %v", err) + } + + r := controlReader{rw} + m, err := r.next() + if err != nil { + return nil, fmt.Errorf("mux hello read failed: %v", err) + } + if m.messageType != muxMsgHello { + return nil, fmt.Errorf("mux reply not hello") + } + if v, err := m.readUint32(); err != nil || v != 4 { + return nil, fmt.Errorf("mux reply hello has bad protocol version") + } + m, err = r.next() + if err != nil { + return nil, fmt.Errorf("error reading mux server proxy: %v", err) + } + if m.messageType != muxSvrProxy { + return nil, fmt.Errorf("expected server proxy response got %d", m.messageType) + } + return &controlProxyTransport{rw}, nil +} + +// controlProxyTransport implements the connTransport interface for +// ControlMaster connections. Each controlMessage has zero length padding and +// no MAC. +type controlProxyTransport struct { + rw io.ReadWriteCloser +} + +func (p *controlProxyTransport) Close() error { + return p.rw.Close() +} + +func (p *controlProxyTransport) getSessionID() []byte { + return nil +} + +func (p *controlProxyTransport) readPacket() ([]byte, error) { + var l uint32 + err := binary.Read(p.rw, binary.BigEndian, &l) + if err == nil { + buf := &bytes.Buffer{} + _, err = io.CopyN(buf, p.rw, int64(l)) + if err == nil { + // Discard the padding byte. + _, _ = buf.ReadByte() + return buf.Bytes(), nil + } + } + return nil, err +} + +func (p *controlProxyTransport) writePacket(controlMessage []byte) error { + l := uint32(len(controlMessage)) + 1 + b := &bytes.Buffer{} + _ = binary.Write(b, binary.BigEndian, &l) // controlMessage Length. + b.WriteByte(0) // Padding Length. + b.Write(controlMessage) + _, err := p.rw.Write(b.Bytes()) + return err +} + +func (p *controlProxyTransport) waitSession() error { + return nil +} + +type controlBuffer struct { + bytes.Buffer +} + +func (b *controlBuffer) writeUint32(i uint32) { + _ = binary.Write(b, binary.BigEndian, i) +} + +func (b *controlBuffer) lengthPrefixedBytes() []byte { + b2 := &bytes.Buffer{} + _ = binary.Write(b2, binary.BigEndian, uint32(b.Len())) + b2.Write(b.Bytes()) + return b2.Bytes() +} + +type controlMessage struct { + body bytes.Buffer + messageType uint32 +} + +func (p controlMessage) readUint32() (uint32, error) { + var u uint32 + err := binary.Read(&p.body, binary.BigEndian, &u) + return u, err +} + +func (p controlMessage) readString() (string, error) { + var l uint32 + err := binary.Read(&p.body, binary.BigEndian, &l) + if err != nil { + return "", fmt.Errorf("error reading string length: %v", err) + } + b := p.body.Next(int(l)) + if len(b) != int(l) { + return string(b), fmt.Errorf("EOF on string read") + } + return string(b), nil +} + +type controlReader struct { + r io.Reader +} + +func (r controlReader) next() (*controlMessage, error) { + p := &controlMessage{} + var len uint32 + err := binary.Read(r.r, binary.BigEndian, &len) + if err != nil { + return nil, fmt.Errorf("error reading message length: %v", err) + } + _, err = io.CopyN(&p.body, r.r, int64(len)) + if err != nil { + return nil, fmt.Errorf("error reading message payload: %v", err) + } + err = binary.Read(&p.body, binary.BigEndian, &p.messageType) + if err != nil { + return nil, fmt.Errorf("error reading message type: %v", err) + } + if p.messageType == muxSFailure { + reason, _ := p.readString() + return nil, fmt.Errorf("server failure: '%s'", reason) + } + return p, nil +} diff --git a/tssh/ctrl_unix.go b/tssh/ctrl_unix.go new file mode 100644 index 0000000..9325630 --- /dev/null +++ b/tssh/ctrl_unix.go @@ -0,0 +1,276 @@ +//go:build !windows + +/* +MIT License + +Copyright (c) 2023 Lonny Wong +Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tssh + +import ( + "bytes" + "fmt" + "io" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strconv" + "strings" + "sync/atomic" + "syscall" + "time" + + "golang.org/x/crypto/ssh" +) + +type controlMaster struct { + path string + args []string + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + loggingIn atomic.Bool + exited atomic.Bool +} + +func (c *controlMaster) readStderr() { + go func() { + defer c.stderr.Close() + buf := make([]byte, 100) + for c.loggingIn.Load() { + n, err := c.stderr.Read(buf) + if n > 0 { + fmt.Fprintf(os.Stderr, "%s", string(buf[:n])) + } + if err != nil { + break + } + } + }() +} + +func (c *controlMaster) readStdout() <-chan error { + done := make(chan error, 1) + go func() { + defer close(done) + buf := make([]byte, 1000) + n, err := c.stdout.Read(buf) + if err != nil { + done <- fmt.Errorf("stdout read failed: %v", err) + return + } + if !bytes.Equal(bytes.TrimSpace(buf[:n]), []byte("ok")) { + done <- fmt.Errorf("stdout invalid: %v", buf[:n]) + return + } + done <- nil + }() + return done +} + +func (c *controlMaster) checkExit() <-chan struct{} { + exit := make(chan struct{}, 1) + go func() { + defer close(exit) + _ = c.cmd.Wait() + c.exited.Store(true) + exit <- struct{}{} + }() + return exit +} + +func (c *controlMaster) start() error { + var err error + c.cmd = exec.Command(c.path, c.args...) + c.stdin, err = c.cmd.StdinPipe() + if err != nil { + return fmt.Errorf("stdin pipe failed: %v", err) + } + c.stdout, err = c.cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("stdout pipe failed: %v", err) + } + c.stderr, err = c.cmd.StderrPipe() + if err != nil { + return fmt.Errorf("stderr pipe failed: %v", err) + } + if err := c.cmd.Start(); err != nil { + return fmt.Errorf("start failed: %v", err) + } + + c.loggingIn.Store(true) + defer func() { + c.loggingIn.Store(false) + }() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + defer func() { signal.Stop(interrupt); close(interrupt) }() + + c.readStderr() + exit := c.checkExit() + done := c.readStdout() + + onExitFuncs = append(onExitFuncs, func() { + c.quit(exit) + }) + + for { + select { + case err := <-done: + return err + case <-exit: + return fmt.Errorf("process exited") + case <-interrupt: + c.quit(exit) + return fmt.Errorf("interrupt") + } + } +} + +func (c *controlMaster) quit(exit <-chan struct{}) { + if c.exited.Load() { + return + } + _, _ = c.stdin.Write([]byte("\x03")) // ctrl + c + _ = c.cmd.Process.Signal(syscall.SIGTERM) + timer := time.AfterFunc(200*time.Millisecond, func() { + _ = c.cmd.Process.Kill() + }) + <-exit + timer.Stop() +} + +func getRealPath(path string) string { + realPath, err := filepath.EvalSymlinks(path) + if err != nil { + return path + } + return realPath +} + +func getOpenSSH() (string, error) { + sshPath := "/usr/bin/ssh" + tsshPath, err := os.Executable() + if err != nil { + return "", err + } + if getRealPath(tsshPath) == getRealPath(sshPath) { + return "", fmt.Errorf("%s is the current program", sshPath) + } + return sshPath, nil +} + +func startControlMaster(args *sshArgs) { + sshPath, err := getOpenSSH() + if err != nil { + warning("can't find ssh to start control master: %v", err) + return + } + + cmdArgs := []string{"-a", "-T", "-oClearAllForwardings=yes", "-oRemoteCommand=none", "-oConnectTimeout=5"} + if args.Debug { + cmdArgs = append(cmdArgs, "-v") + } + if args.LoginName != "" { + cmdArgs = append(cmdArgs, "-l", args.LoginName) + } + if args.Port != 0 { + cmdArgs = append(cmdArgs, "-p", strconv.Itoa(args.Port)) + } + for _, identity := range args.Identity.values { + cmdArgs = append(cmdArgs, "-i", identity) + } + if args.ConfigFile != "" { + cmdArgs = append(cmdArgs, "-F", args.ConfigFile) + } + if args.ProxyJump != "" { + cmdArgs = append(cmdArgs, "-J", args.ProxyJump) + } + + for key, value := range args.Option.options { + switch key { + case "controlmaster": + cmdArgs = append(cmdArgs, fmt.Sprintf("-oControlMaster=%s", value)) + case "controlpath": + cmdArgs = append(cmdArgs, fmt.Sprintf("-oControlPath=%s", value)) + case "controlpersist": + cmdArgs = append(cmdArgs, fmt.Sprintf("-oControlPersist=%s", value)) + } + } + + if args.originalDest != "" { + cmdArgs = append(cmdArgs, args.originalDest) + } else { + cmdArgs = append(cmdArgs, args.Destination) + } + // sleep 2147483 for PowerShell + cmdArgs = append(cmdArgs, "echo ok; sleep 2147483; sleep infinity") + + if enableDebugLogging { + debug("control master: %s %s", sshPath, strings.Join(cmdArgs, " ")) + } + + ctrlMaster := &controlMaster{path: sshPath, args: cmdArgs} + if err := ctrlMaster.start(); err != nil { + warning("start control master failed: %v", err) + return + } + debug("start control master success") +} + +func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client { + ctrlMaster := getOptionConfig(args, "ControlMaster") + ctrlPath := getOptionConfig(args, "ControlPath") + + switch strings.ToLower(ctrlMaster) { + case "auto", "yes", "ask", "autoask": + startControlMaster(args) + } + + switch strings.ToLower(ctrlPath) { + case "", "none": + return nil + } + + unixAddr := resolveHomeDir(expandTokens(ctrlPath, args, param, "%CdhikLlnpru")) + debug("login to [%s], socket: %s", args.Destination, unixAddr) + + conn, err := net.DialTimeout("unix", unixAddr, time.Second) + if err != nil { + warning("dial ctrl unix [%s] failed: %v", unixAddr, err) + return nil + } + + ncc, chans, reqs, err := NewControlClientConn(conn) + if err != nil { + warning("new ctrl conn [%s] failed: %v", unixAddr, err) + return nil + } + + debug("login to [%s] success", args.Destination) + return ssh.NewClient(ncc, chans, reqs) +} diff --git a/tssh/ctrl_windows.go b/tssh/ctrl_windows.go new file mode 100644 index 0000000..f77a4cf --- /dev/null +++ b/tssh/ctrl_windows.go @@ -0,0 +1,50 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong +Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tssh + +import ( + "strings" + + "golang.org/x/crypto/ssh" +) + +func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client { + ctrlMaster := getOptionConfig(args, "ControlMaster") + ctrlPath := getOptionConfig(args, "ControlPath") + + switch strings.ToLower(ctrlMaster) { + case "auto", "yes", "ask", "autoask": + warning("ControlMaster is not supported on Windows") + } + + switch strings.ToLower(ctrlPath) { + case "", "none": + return nil + } + + warning("ControlPath is not supported on Windows") + return nil +} diff --git a/tssh/forward.go b/tssh/forward.go index 5853521..24e0ba7 100644 --- a/tssh/forward.go +++ b/tssh/forward.go @@ -379,10 +379,25 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) { } func sshForward(client *ssh.Client, args *sshArgs) error { - // dynamic forward + // command dynamic forward for _, b := range args.DynamicForward.binds { dynamicForward(client, b, args) } + // command local forward + for _, f := range args.LocalForward.cfgs { + localForward(client, f, args) + } + // command remote forward + for _, f := range args.RemoteForward.cfgs { + remoteForward(client, f, args) + } + + // clear all forwardings + if strings.ToLower(args.Option.get("ClearAllForwardings")) == "yes" { + return nil + } + + // config dynamic forward for _, s := range getAllConfig(args.Destination, "DynamicForward") { b, err := parseBindCfg(s) if err != nil { @@ -392,10 +407,7 @@ func sshForward(client *ssh.Client, args *sshArgs) error { dynamicForward(client, b, args) } - // local forward - for _, f := range args.LocalForward.cfgs { - localForward(client, f, args) - } + // config local forward for _, s := range getAllConfig(args.Destination, "LocalForward") { f, err := parseForwardCfg(s) if err != nil { @@ -405,10 +417,7 @@ func sshForward(client *ssh.Client, args *sshArgs) error { localForward(client, f, args) } - // remote forward - for _, f := range args.RemoteForward.cfgs { - remoteForward(client, f, args) - } + // config remote forward for _, s := range getAllConfig(args.Destination, "RemoteForward") { f, err := parseForwardCfg(s) if err != nil { diff --git a/tssh/login.go b/tssh/login.go index d7afa56..82c419a 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -56,7 +56,7 @@ func debug(format string, a ...any) { fmt.Fprintf(os.Stderr, fmt.Sprintf("\033[0;36mdebug:\033[0m %s\r\n", format), a...) } -func warning(format string, a ...any) { +var warning = func(format string, a ...any) { fmt.Fprintf(os.Stderr, fmt.Sprintf("\033[0;33mWarning: %s\033[0m\r\n", format), a...) } @@ -318,7 +318,9 @@ func (s *sshSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { if err := s.initSigner(); err != nil { return nil, err } - debug("sign without algorithm: %s", ssh.FingerprintSHA256(s.pubKey)) + if enableDebugLogging { + debug("sign without algorithm: %s", ssh.FingerprintSHA256(s.pubKey)) + } return s.signer.Sign(rand, data) } @@ -327,10 +329,14 @@ func (s *sshSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm str return nil, err } if signer, ok := s.signer.(ssh.AlgorithmSigner); ok { - debug("sign with algorithm [%s]: %s", algorithm, ssh.FingerprintSHA256(s.pubKey)) + if enableDebugLogging { + debug("sign with algorithm [%s]: %s", algorithm, ssh.FingerprintSHA256(s.pubKey)) + } return signer.SignWithAlgorithm(rand, data, algorithm) } - debug("sign without algorithm: %s", ssh.FingerprintSHA256(s.pubKey)) + if enableDebugLogging { + debug("sign without algorithm: %s", ssh.FingerprintSHA256(s.pubKey)) + } return s.signer.Sign(rand, data) } @@ -500,7 +506,9 @@ func getPublicKeysAuthMethod(args *sshArgs) ssh.AuthMethod { for _, signer := range signers { fingerprint := ssh.FingerprintSHA256(signer.PublicKey()) if _, ok := fingerprints[fingerprint]; !ok { - debug("will attempt key: %s %s %s", signer.path, signer.pubKey.Type(), ssh.FingerprintSHA256(signer.pubKey)) + if enableDebugLogging { + debug("will attempt key: %s %s %s", signer.path, signer.pubKey.Type(), ssh.FingerprintSHA256(signer.pubKey)) + } fingerprints[fingerprint] = struct{}{} pubKeySigners = append(pubKeySigners, signer) } @@ -621,11 +629,8 @@ func (p *cmdPipe) Close() error { return err2 } -func execProxyCommand(param *loginParam) (net.Conn, string, error) { - command := param.command - command = strings.ReplaceAll(command, "%h", param.host) - command = strings.ReplaceAll(command, "%p", param.port) - command = strings.ReplaceAll(command, "%r", param.user) +func execProxyCommand(args *sshArgs, param *loginParam) (net.Conn, string, error) { + command := resolveHomeDir(expandTokens(param.command, args, param, "%hnpr")) debug("exec proxy command: %s", command) argv, err := splitCommandLine(command) @@ -694,15 +699,20 @@ func (c *connWithTimeout) Read(b []byte) (n int, err error) { return } -func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, error) { +func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, bool, error) { param, err := getLoginParam(args) if err != nil { - return nil, err + return nil, false, err } + + if client := connectViaControl(args, param); client != nil { + return client, true, nil + } + authMethods := getAuthMethods(args, param.host, param.user) cb, kh, err := getHostKeyCallback() if err != nil { - return nil, err + return nil, false, err } config := &ssh.ClientConfig{ User: param.user, @@ -716,18 +726,18 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e }, } - proxyConnect := func(client *ssh.Client, proxy string) (*ssh.Client, error) { + proxyConnect := func(client *ssh.Client, proxy string) (*ssh.Client, bool, error) { debug("login to [%s], addr: %s", args.Destination, param.addr) conn, err := dialWithTimeout(client, "tcp", param.addr, 10*time.Second) if err != nil { - return nil, fmt.Errorf("proxy [%s] dial tcp [%s] failed: %v", proxy, param.addr, err) + return nil, false, fmt.Errorf("proxy [%s] dial tcp [%s] failed: %v", proxy, param.addr, err) } ncc, chans, reqs, err := ssh.NewClientConn(&connWithTimeout{conn, config.Timeout, true}, param.addr, config) if err != nil { - return nil, fmt.Errorf("proxy [%s] new conn [%s] failed: %v", proxy, param.addr, err) + return nil, false, fmt.Errorf("proxy [%s] new conn [%s] failed: %v", proxy, param.addr, err) } debug("login to [%s] success", args.Destination) - return ssh.NewClient(ncc, chans, reqs), nil + return ssh.NewClient(ncc, chans, reqs), false, nil } // has parent client @@ -738,16 +748,16 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e // proxy command if param.command != "" { debug("login to [%s], addr: %s", args.Destination, param.addr) - conn, cmd, err := execProxyCommand(param) + conn, cmd, err := execProxyCommand(args, param) if err != nil { - return nil, fmt.Errorf("exec proxy command [%s] failed: %v", cmd, err) + return nil, false, fmt.Errorf("exec proxy command [%s] failed: %v", cmd, err) } ncc, chans, reqs, err := ssh.NewClientConn(conn, param.addr, config) if err != nil { - return nil, fmt.Errorf("proxy command [%s] new conn [%s] failed: %v", cmd, param.addr, err) + return nil, false, fmt.Errorf("proxy command [%s] new conn [%s] failed: %v", cmd, param.addr, err) } debug("login to [%s] success", args.Destination) - return ssh.NewClient(ncc, chans, reqs), nil + return ssh.NewClient(ncc, chans, reqs), false, nil } // no proxy @@ -755,22 +765,22 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e debug("login to [%s], addr: %s", args.Destination, param.addr) conn, err := net.DialTimeout("tcp", param.addr, config.Timeout) if err != nil { - return nil, fmt.Errorf("dial tcp [%s] failed: %v", param.addr, err) + return nil, false, fmt.Errorf("dial tcp [%s] failed: %v", param.addr, err) } ncc, chans, reqs, err := ssh.NewClientConn(&connWithTimeout{conn, config.Timeout, true}, param.addr, config) if err != nil { - return nil, fmt.Errorf("new conn [%s] failed: %v", param.addr, err) + return nil, false, fmt.Errorf("new conn [%s] failed: %v", param.addr, err) } debug("login to [%s] success", args.Destination) - return ssh.NewClient(ncc, chans, reqs), nil + return ssh.NewClient(ncc, chans, reqs), false, nil } // has proxies var proxyClient *ssh.Client for _, proxy = range param.proxy { - proxyClient, err = sshConnect(&sshArgs{Destination: proxy}, proxyClient, proxy) + proxyClient, _, err = sshConnect(&sshArgs{Destination: proxy}, proxyClient, proxy) if err != nil { - return nil, err + return nil, false, err } } return proxyConnect(proxyClient, proxy) @@ -813,14 +823,14 @@ func keepAlive(client *ssh.Client, args *sshArgs) { } func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) { - agentClient := getAgentClient() - if agentClient == nil { - return - } if args.NoForwardAgent || !args.ForwardAgent && strings.ToLower(getOptionConfig(args, "ForwardAgent")) != "yes" { closeAgentClient() return } + agentClient := getAgentClient() + if agentClient == nil { + return + } if err := agent.ForwardToAgent(client, agentClient); err != nil { warning("forward to agent failed: %v", err) return @@ -850,13 +860,16 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session }) // ssh login - client, err = sshConnect(args, nil, "") + var control bool + client, control, err = sshConnect(args, nil, "") if err != nil { return } // keep alive - keepAlive(client, args) + if !control { + keepAlive(client, args) + } // no command if args.NoCommand || args.StdioForward != "" { diff --git a/tssh/main.go b/tssh/main.go index d1a5b0f..8ea1e2c 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -243,6 +243,7 @@ func TsshMain() int { } } args.Destination = dest + args.originalDest = dest // start ssh program if err = sshStart(&args); err != nil { diff --git a/tssh/prompt.go b/tssh/prompt.go index 231d1c6..d5c6ded 100644 --- a/tssh/prompt.go +++ b/tssh/prompt.go @@ -645,7 +645,7 @@ func fastLookupHost(host string) bool { } func predictDestination(dest string) (string, bool, error) { - if strings.ContainsRune(dest, '.') || strings.ContainsRune(dest, ':') { + if strings.ContainsAny(dest, ".:[]@") { return dest, false, nil } diff --git a/tssh/tokens.go b/tssh/tokens.go new file mode 100644 index 0000000..37d3401 --- /dev/null +++ b/tssh/tokens.go @@ -0,0 +1,95 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong +Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tssh + +import ( + "crypto/sha1" + "fmt" + "os" + "strings" +) + +var getHostname = func() string { + hostname, err := os.Hostname() + if err != nil { + warning("get hostname failed: %v", err) + return "" + } + return hostname +} + +func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) string { + var buf strings.Builder + state := byte(0) + for _, c := range str { + if state == 0 { + if c == '%' { + state = '%' + continue + } + buf.WriteRune(c) + continue + } + state = 0 + if !strings.ContainsRune(tokens, c) { + warning("token [%%%c] in [%s] is not supported", c, str) + buf.WriteRune('%') + buf.WriteRune(c) + continue + } + switch c { + case '%': + buf.WriteRune('%') + case 'h': + buf.WriteString(param.host) + case 'p': + buf.WriteString(param.port) + case 'r': + buf.WriteString(param.user) + case 'n': + buf.WriteString(args.Destination) + case 'l': + buf.WriteString(getHostname()) + case 'L': + hostname := getHostname() + if idx := strings.IndexByte(hostname, '.'); idx >= 0 { + hostname = hostname[:idx] + } + buf.WriteString(hostname) + case 'C': + hashStr := fmt.Sprintf("%s%s%s%s", getHostname(), param.host, param.port, param.user) + buf.WriteString(fmt.Sprintf("%x", sha1.Sum([]byte(hashStr)))) + default: + warning("token [%%%c] in [%s] is not supported yet", c, str) + buf.WriteRune('%') + buf.WriteRune(c) + } + } + if state != 0 { + warning("[%s] ends with %% is invalid", str) + } + return buf.String() +} diff --git a/tssh/tokens_test.go b/tssh/tokens_test.go new file mode 100644 index 0000000..1432511 --- /dev/null +++ b/tssh/tokens_test.go @@ -0,0 +1,93 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong +Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tssh + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandTokens(t *testing.T) { + assert := assert.New(t) + originalWarning := warning + defer func() { + warning = originalWarning + }() + var output string + warning = func(format string, a ...any) { + output = fmt.Sprintf(format, a...) + } + originalGetHostname := getHostname + defer func() { + getHostname = originalGetHostname + }() + getHostname = func() string { return "myhostname.mydomain.com" } + + args := &sshArgs{ + Destination: "dest", + } + param := &loginParam{ + host: "127.0.0.1", + port: "1337", + user: "penny", + } + assertProxyCommand := func(original, expanded, result string) { + t.Helper() + output = "" + assert.Equal(expanded, expandTokens(original, args, param, "%hnpr")) + assert.Equal(result, output) + } + + assertProxyCommand("%%", "%", "") + assertProxyCommand("%h", "127.0.0.1", "") + assertProxyCommand("%n", "dest", "") + assertProxyCommand("%p", "1337", "") + assertProxyCommand("%r", "penny", "") + assertProxyCommand("a_%%_%r_%p_%n_%h_Z", "a_%_penny_1337_dest_127.0.0.1_Z", "") + + assertProxyCommand("%l", "%l", "token [%l] in [%l] is not supported") + assertProxyCommand("a_%h_%C", "a_127.0.0.1_%C", "token [%C] in [a_%h_%C] is not supported") + + assertControlPath := func(original, expanded, result string) { + t.Helper() + output = "" + assert.Equal(expanded, expandTokens(original, args, param, "%CdhikLlnpru")) + assert.Equal(result, output) + } + + assertControlPath("%p和%r", "1337和penny", "") + assertControlPath("%%%h%n", "%127.0.0.1dest", "") + assertControlPath("%L", "myhostname", "") + assertControlPath("%l", "myhostname.mydomain.com", "") + + assertControlPath("/A/%C/B", "/A/07f25c03a322b120bcaa54d2dd0a618f2673cb1c/B", "") + + assertControlPath("%j", "%j", "token [%j] in [%j] is not supported") + assertControlPath("p_%h_%d", "p_127.0.0.1_%d", "token [%d] in [p_%h_%d] is not supported yet") + assertControlPath("h%", "h", "[h%] ends with % is invalid") +}