From 6c684a0849b1af15cc01780b0ec89be432faca3b Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 27 Jan 2024 20:09:52 +0800 Subject: [PATCH] expand supported options --- tssh/agent.go | 24 ++++-- tssh/ctrl_unix.go | 4 +- tssh/ctrl_windows.go | 2 +- tssh/expect.go | 16 ++-- tssh/forward.go | 16 +++- tssh/login.go | 193 ++++++++++++++++++++++++++++++++++--------- tssh/main.go | 91 ++++---------------- tssh/tokens.go | 5 +- tssh/tokens_test.go | 6 +- tssh/trzsz.go | 26 +++--- 10 files changed, 225 insertions(+), 158 deletions(-) diff --git a/tssh/agent.go b/tssh/agent.go index 8cc699d..757b7a4 100644 --- a/tssh/agent.go +++ b/tssh/agent.go @@ -41,25 +41,33 @@ var ( agentClient agent.ExtendedAgent ) -func getAgentAddr(args *sshArgs) string { +func getAgentAddr(args *sshArgs, param *sshParam) (string, error) { if addr := getOptionConfig(args, "IdentityAgent"); addr != "" { if strings.ToLower(addr) == "none" { - return "" + return "", nil } - return addr + expandedAddr, err := expandTokens(addr, args, param, "%CdhikLlnpru") + if err != nil { + return "", fmt.Errorf("expand IdentityAgent [%s] failed: %v", addr, err) + } + return resolveHomeDir(expandedAddr), nil } if addr := os.Getenv("SSH_AUTH_SOCK"); addr != "" { - return addr + return resolveHomeDir(addr), nil } if addr := defaultAgentAddr; addr != "" && isFileExist(addr) { - return addr + return addr, nil } - return "" + return "", nil } -func getAgentClient(args *sshArgs) agent.ExtendedAgent { +func getAgentClient(args *sshArgs, param *sshParam) agent.ExtendedAgent { agentOnce.Do(func() { - addr := resolveHomeDir(getAgentAddr(args)) + addr, err := getAgentAddr(args, param) + if err != nil { + warning("get agent addr failed: %v", err) + return + } if addr == "" { debug("ssh agent address is not set") return diff --git a/tssh/ctrl_unix.go b/tssh/ctrl_unix.go index c2782e7..f004efb 100644 --- a/tssh/ctrl_unix.go +++ b/tssh/ctrl_unix.go @@ -311,7 +311,7 @@ func startControlMaster(args *sshArgs) error { return nil } -func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client { +func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client { ctrlMaster := getOptionConfig(args, "ControlMaster") ctrlPath := getOptionConfig(args, "ControlPath") @@ -322,7 +322,7 @@ func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client { socket, err := expandTokens(ctrlPath, args, param, "%CdhikLlnpru") if err != nil { - warning("expand control socket [%s] failed: %v", socket, err) + warning("expand ControlPath [%s] failed: %v", socket, err) return nil } socket = resolveHomeDir(socket) diff --git a/tssh/ctrl_windows.go b/tssh/ctrl_windows.go index 7595ddb..e5c6e6a 100644 --- a/tssh/ctrl_windows.go +++ b/tssh/ctrl_windows.go @@ -30,7 +30,7 @@ import ( "golang.org/x/crypto/ssh" ) -func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client { +func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client { ctrlMaster := getOptionConfig(args, "ControlMaster") ctrlPath := getOptionConfig(args, "ControlPath") diff --git a/tssh/expect.go b/tssh/expect.go index 861e41a..62a2bcb 100644 --- a/tssh/expect.go +++ b/tssh/expect.go @@ -430,11 +430,10 @@ func getExpectTimeout(args *sshArgs, prefix string) int { return int(count) } -func execExpectInteractions(args *sshArgs, serverIn io.Writer, - serverOut io.Reader, serverErr io.Reader) (io.Reader, io.Reader) { +func execExpectInteractions(args *sshArgs, ss *sshSession) { expectCount := getExpectCount(args, "") if expectCount <= 0 { - return serverOut, serverErr + return } outReader, outWriter := io.Pipe() @@ -456,15 +455,16 @@ func execExpectInteractions(args *sshArgs, serverIn io.Writer, out: make(chan []byte, 10), err: make(chan []byte, 10), } - go expect.wrapOutput(serverOut, outWriter, expect.out) - go expect.wrapOutput(serverErr, errWriter, expect.err) + go expect.wrapOutput(ss.serverOut, outWriter, expect.out) + go expect.wrapOutput(ss.serverErr, errWriter, expect.err) - expect.execInteractions(serverIn, expectCount) + expect.execInteractions(ss.serverIn, expectCount) if ctx.Err() == context.DeadlineExceeded { warning("expect timeout after %d seconds", expectTimeout) - _, _ = serverIn.Write([]byte("\r")) // enter for shell prompt if timeout + _, _ = ss.serverIn.Write([]byte("\r")) // enter for shell prompt if timeout } - return outReader, errReader + ss.serverOut = outReader + ss.serverErr = errReader } diff --git a/tssh/forward.go b/tssh/forward.go index e151199..baf39b9 100644 --- a/tssh/forward.go +++ b/tssh/forward.go @@ -382,7 +382,7 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) { } } -func sshForward(client *ssh.Client, args *sshArgs) error { +func sshForward(client *ssh.Client, args *sshArgs, param *sshParam) error { // clear all forwardings if strings.ToLower(getOptionConfig(args, "ClearAllForwardings")) == "yes" { return nil @@ -406,7 +406,12 @@ func sshForward(client *ssh.Client, args *sshArgs) error { localForward(client, f, args) } for _, s := range getAllOptionConfig(args, "LocalForward") { - f, err := parseForwardCfg(s) + es, err := expandTokens(s, args, param, "%CdhikLlnpru") + if err != nil { + warning("expand LocalForward [%s] failed: %v", s, err) + continue + } + f, err := parseForwardCfg(es) if err != nil { warning("local forward failed: %v", err) continue @@ -419,7 +424,12 @@ func sshForward(client *ssh.Client, args *sshArgs) error { remoteForward(client, f, args) } for _, s := range getAllOptionConfig(args, "RemoteForward") { - f, err := parseForwardCfg(s) + es, err := expandTokens(s, args, param, "%CdhikLlnpru") + if err != nil { + warning("expand RemoteForward [%s] failed: %v", s, err) + continue + } + f, err := parseForwardCfg(es) if err != nil { warning("remote forward failed: %v", err) continue diff --git a/tssh/login.go b/tssh/login.go index 354eb52..50667dd 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -42,6 +42,7 @@ import ( "sync/atomic" "time" + "github.com/alessio/shellescape" "github.com/skeema/knownhosts" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -65,7 +66,7 @@ var warning = func(format string, a ...any) { fmt.Fprintf(os.Stderr, fmt.Sprintf("\033[0;33mWarning: %s\033[0m\r\n", format), a...) } -type loginParam struct { +type sshParam struct { host string port string user string @@ -74,6 +75,28 @@ type loginParam struct { command string } +type sshSession struct { + client *ssh.Client + session *ssh.Session + serverIn io.WriteCloser + serverOut io.Reader + serverErr io.Reader + cmd string + tty bool +} + +func (s *sshSession) Close() { + if s.serverIn != nil { + s.serverIn.Close() + } + if s.session != nil { + s.session.Close() + } + if s.client != nil { + s.client.Close() + } +} + func joinHostPort(host, port string) string { if !strings.HasPrefix(host, "[") && strings.ContainsRune(host, ':') { return fmt.Sprintf("[%s]:%s", host, port) @@ -106,8 +129,8 @@ func parseDestination(dest string) (user, host, port string) { return } -func getLoginParam(args *sshArgs) (*loginParam, error) { - param := &loginParam{} +func getSshParam(args *sshArgs) (*sshParam, error) { + param := &sshParam{} // login dest destUser, destHost, destPort := parseDestination(args.Destination) @@ -183,6 +206,21 @@ func getLoginParam(args *sshArgs) (*loginParam, error) { } } + // expand proxy + var err error + if param.command != "" { + param.command, err = expandTokens(param.command, args, param, "%hnpr") + if err != nil { + return nil, fmt.Errorf("expand ProxyCommand [%s] failed: %v", param.command, err) + } + } + for i := 0; i < len(param.proxy); i++ { + param.proxy[i], err = expandTokens(param.proxy[i], args, param, "%hnpr") + if err != nil { + return nil, fmt.Errorf("expand ProxyJump [%s] failed: %v", param.proxy[i], err) + } + } + return param, nil } @@ -271,7 +309,7 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) return nil } -func getHostKeyCallback(args *sshArgs, param *loginParam) (ssh.HostKeyCallback, knownhosts.HostKeyCallback, error) { +func getHostKeyCallback(args *sshArgs, param *sshParam) (ssh.HostKeyCallback, knownhosts.HostKeyCallback, error) { primaryPath := "" var files []string addKnownHostsFiles := func(key string, user bool) error { @@ -285,7 +323,7 @@ func getHostKeyCallback(args *sshArgs, param *loginParam) (ssh.HostKeyCallback, if user { expandedPath, err := expandTokens(path, args, param, "%CdhikLlnpru") if err != nil { - return err + return fmt.Errorf("expand UserKnownHostsFile [%s] failed: %v", path, err) } resolvedPath = resolveHomeDir(expandedPath) if primaryPath == "" { @@ -632,7 +670,7 @@ var getDefaultSigners = func() func() []*sshSigner { } }() -func getPublicKeysAuthMethod(args *sshArgs) ssh.AuthMethod { +func getPublicKeysAuthMethod(args *sshArgs, param *sshParam) ssh.AuthMethod { if strings.ToLower(getOptionConfig(args, "PubkeyAuthentication")) == "no" { debug("disable auth method: public key authentication") return nil @@ -653,7 +691,7 @@ func getPublicKeysAuthMethod(args *sshArgs) ssh.AuthMethod { } } - if agentClient := getAgentClient(args); agentClient != nil { + if agentClient := getAgentClient(args, param); agentClient != nil { signers, err := agentClient.Signers() if err != nil { warning("get ssh agent signers failed: %v", err) @@ -664,7 +702,16 @@ func getPublicKeysAuthMethod(args *sshArgs) ssh.AuthMethod { } } - identities := append(args.Identity.values, getAllOptionConfig(args, "IdentityFile")...) + identities := args.Identity.values + for _, identity := range getAllOptionConfig(args, "IdentityFile") { + expandedIdentity, err := expandTokens(identity, args, param, "%CdhikLlnpru") + if err != nil { + warning("expand IdentityFile [%s] failed: %v", identity, err) + continue + } + identities = append(identities, expandedIdentity) + } + if len(identities) == 0 { addPubKeySigners(getDefaultSigners()) } else { @@ -681,17 +728,17 @@ func getPublicKeysAuthMethod(args *sshArgs) ssh.AuthMethod { return ssh.PublicKeys(pubKeySigners...) } -func getAuthMethods(args *sshArgs, host, user string) []ssh.AuthMethod { +func getAuthMethods(args *sshArgs, param *sshParam) []ssh.AuthMethod { var authMethods []ssh.AuthMethod - if authMethod := getPublicKeysAuthMethod(args); authMethod != nil { + if authMethod := getPublicKeysAuthMethod(args, param); authMethod != nil { debug("add auth method: public key authentication") authMethods = append(authMethods, authMethod) } - if authMethod := getKeyboardInteractiveAuthMethod(args, host, user); authMethod != nil { + if authMethod := getKeyboardInteractiveAuthMethod(args, param.host, param.user); authMethod != nil { debug("add auth method: keyboard interactive authentication") authMethods = append(authMethods, authMethod) } - if authMethod := getPasswordAuthMethod(args, host, user); authMethod != nil { + if authMethod := getPasswordAuthMethod(args, param.host, param.user); authMethod != nil { debug("add auth method: password authentication") authMethods = append(authMethods, authMethod) } @@ -753,7 +800,7 @@ func (p *cmdPipe) Close() error { return err2 } -func execProxyCommand(args *sshArgs, param *loginParam) (net.Conn, string, error) { +func execProxyCommand(args *sshArgs, param *sshParam) (net.Conn, string, error) { command, err := expandTokens(param.command, args, param, "%hnpr") if err != nil { return nil, param.command, err @@ -787,7 +834,7 @@ func execProxyCommand(args *sshArgs, param *loginParam) (net.Conn, string, error return &cmdPipe{stdin: cmdIn, stdout: cmdOut, addr: param.addr}, command, nil } -func execLocalCommand(args *sshArgs, param *loginParam) { +func execLocalCommand(args *sshArgs, param *sshParam) { if strings.ToLower(getOptionConfig(args, "PermitLocalCommand")) != "yes" { return } @@ -797,7 +844,7 @@ func execLocalCommand(args *sshArgs, param *loginParam) { } expandedCmd, err := expandTokens(localCmd, args, param, "%CdfHhIiKkLlnprTtu") if err != nil { - warning("expand local command [%s] failed: %v", localCmd, err) + warning("expand LocalCommand [%s] failed: %v", localCmd, err) return } resolvedCmd := resolveHomeDir(expandedCmd) @@ -822,6 +869,65 @@ func execLocalCommand(args *sshArgs, param *loginParam) { } } +func parseRemoteCommand(args *sshArgs, param *sshParam) (string, error) { + command := args.Option.get("RemoteCommand") + if args.Command != "" && command != "" && strings.ToLower(command) != "none" { + return "", fmt.Errorf("cannot execute command-line and remote command") + } + if args.Command != "" { + if len(args.Argument) == 0 { + return args.Command, nil + } + return shellescape.QuoteCommand(append([]string{args.Command}, args.Argument...)), nil + } + if strings.ToLower(command) == "none" { + return "", nil + } + if command == "" { + command = getConfig(args.Destination, "RemoteCommand") + } + expandedCmd, err := expandTokens(command, args, param, "%CdhikLlnpru") + if err != nil { + return "", fmt.Errorf("expand RemoteCommand [%s] failed: %v", command, err) + } + return expandedCmd, nil +} + +func parseCmdAndTTY(args *sshArgs, param *sshParam) (cmd string, tty bool, err error) { + cmd, err = parseRemoteCommand(args, param) + if err != nil { + return + } + + if args.DisableTTY && args.ForceTTY { + err = fmt.Errorf("cannot specify -t with -T") + return + } + if args.DisableTTY { + tty = false + return + } + if args.ForceTTY { + tty = true + return + } + + requestTTY := getConfig(args.Destination, "RequestTTY") + switch strings.ToLower(requestTTY) { + case "", "auto": + tty = isTerminal && (cmd == "") + case "no": + tty = false + case "force": + tty = true + case "yes": + tty = isTerminal + default: + err = fmt.Errorf("unknown RequestTTY option: %s", requestTTY) + } + return +} + func dialWithTimeout(client *ssh.Client, network, addr string, timeout time.Duration) (conn net.Conn, err error) { done := make(chan struct{}, 1) go func() { @@ -890,8 +996,8 @@ func setupLogLevel(args *sshArgs) func() { return reset } -func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, *loginParam, bool, error) { - param, err := getLoginParam(args) +func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, *sshParam, bool, error) { + param, err := getSshParam(args) if err != nil { return nil, nil, false, err } @@ -903,7 +1009,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, * return client, param, true, nil } - authMethods := getAuthMethods(args, param.host, param.user) + authMethods := getAuthMethods(args, param) cb, kh, err := getHostKeyCallback(args, param) if err != nil { return nil, param, false, err @@ -920,7 +1026,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, * }, } - proxyConnect := func(client *ssh.Client, proxy string) (*ssh.Client, *loginParam, bool, error) { + proxyConnect := func(client *ssh.Client, proxy string) (*ssh.Client, *sshParam, bool, error) { debug("login to [%s], addr: %s", args.Destination, param.addr) conn, err := dialWithTimeout(client, "tcp", param.addr, 10*time.Second) if err != nil { @@ -1016,11 +1122,15 @@ func keepAlive(client *ssh.Client, args *sshArgs) { }() } -func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) { +func sshAgentForward(args *sshArgs, param *sshParam, client *ssh.Client, session *ssh.Session) { if args.NoForwardAgent || !args.ForwardAgent && strings.ToLower(getOptionConfig(args, "ForwardAgent")) != "yes" { return } - addr := resolveHomeDir(getAgentAddr(args)) + addr, err := getAgentAddr(args, param) + if err != nil { + warning("get agent addr failed: %v", err) + return + } if addr == "" { warning("forward agent but the socket address is not set") return @@ -1036,17 +1146,12 @@ func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) { debug("request ssh agent forwarding success") } -func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, - serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, err error) { - var param *loginParam +func sshLogin(args *sshArgs) (ss *sshSession, err error) { + ss = &sshSession{} + var param *sshParam defer func() { if err != nil { - if session != nil { - session.Close() - } - if client != nil { - client.Close() - } + ss.Close() } else { sshLoginSuccess.Store(true) // execute local command if necessary @@ -1056,14 +1161,20 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session // ssh login var control bool - client, param, control, err = sshConnect(args, nil, "") + ss.client, param, control, err = sshConnect(args, nil, "") + if err != nil { + return + } + + // parse cmd and tty + ss.cmd, ss.tty, err = parseCmdAndTTY(args, param) if err != nil { return } // keep alive if !control { - keepAlive(client, args) + keepAlive(ss.client, args) } // stdio forward @@ -1073,7 +1184,7 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session // ssh forward if !control { - if err = sshForward(client, args); err != nil { + if err = sshForward(ss.client, args, param); err != nil { return } } @@ -1084,29 +1195,29 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session } // new session - session, err = client.NewSession() + ss.session, err = ss.client.NewSession() if err != nil { err = fmt.Errorf("ssh new session failed: %v", err) return } // send and set env - if err = sendAndSetEnv(args, session); err != nil { + if err = sendAndSetEnv(args, ss.session); err != nil { return } // session input and output - serverIn, err = session.StdinPipe() + ss.serverIn, err = ss.session.StdinPipe() if err != nil { err = fmt.Errorf("stdin pipe failed: %v", err) return } - serverOut, err = session.StdoutPipe() + ss.serverOut, err = ss.session.StdoutPipe() if err != nil { err = fmt.Errorf("stdout pipe failed: %v", err) return } - serverErr, err = session.StderrPipe() + ss.serverErr, err = ss.session.StderrPipe() if err != nil { err = fmt.Errorf("stderr pipe failed: %v", err) return @@ -1114,11 +1225,11 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session // ssh agent forward if !control { - sshAgentForward(args, client, session) + sshAgentForward(args, param, ss.client, ss.session) } // not terminal or not tty - if !isTerminal || !tty { + if !isTerminal || !ss.tty { return } @@ -1132,7 +1243,7 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session if term == "" { term = "xterm-256color" } - if err = session.RequestPty(term, height, width, ssh.TerminalModes{}); err != nil { + if err = ss.session.RequestPty(term, height, width, ssh.TerminalModes{}); err != nil { err = fmt.Errorf("request pty failed: %v", err) return } diff --git a/tssh/main.go b/tssh/main.go index 8ffb072..3290f87 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -117,62 +117,8 @@ func cleanupAfterLogin() { } } -func parseRemoteCommand(args *sshArgs) (string, error) { - command := args.Option.get("RemoteCommand") - if args.Command != "" && command != "" && strings.ToLower(command) != "none" { - return "", fmt.Errorf("cannot execute command-line and remote command") - } - if args.Command != "" { - if len(args.Argument) == 0 { - return args.Command, nil - } - return fmt.Sprintf("%s %s", args.Command, strings.Join(args.Argument, " ")), nil - } - if strings.ToLower(command) == "none" { - return "", nil - } else if command != "" { - return command, nil - } - return getConfig(args.Destination, "RemoteCommand"), nil -} - var isTerminal bool = isatty.IsTerminal(os.Stdin.Fd()) || isatty.IsCygwinTerminal(os.Stdin.Fd()) -func parseCmdAndTTY(args *sshArgs) (cmd string, tty bool, err error) { - cmd, err = parseRemoteCommand(args) - if err != nil { - return - } - - if args.DisableTTY && args.ForceTTY { - err = fmt.Errorf("cannot specify -t with -T") - return - } - if args.DisableTTY { - tty = false - return - } - if args.ForceTTY { - tty = true - return - } - - requestTTY := getConfig(args.Destination, "RequestTTY") - switch strings.ToLower(requestTTY) { - case "", "auto": - tty = isTerminal && (cmd == "") - case "no": - tty = false - case "force": - tty = true - case "yes": - tty = isTerminal - default: - err = fmt.Errorf("unknown RequestTTY option: %s", requestTTY) - } - return -} - func TsshMain() int { var args sshArgs parser := arg.MustParse(&args) @@ -252,26 +198,17 @@ func TsshMain() int { } func sshStart(args *sshArgs) error { - // parse cmd and tty - command, tty, err := parseCmdAndTTY(args) - if err != nil { - return err - } - // ssh login - client, session, serverIn, serverOut, serverErr, err := sshLogin(args, tty) + ss, err := sshLogin(args) if err != nil { return err } - defer client.Close() - if session != nil { - defer session.Close() - } + defer ss.Close() // stdio forward if args.StdioForward != "" { var wg *sync.WaitGroup - wg, err = stdioForward(client, args.StdioForward) + wg, err = stdioForward(ss.client, args.StdioForward) if err != nil { return err } @@ -283,7 +220,7 @@ func sshStart(args *sshArgs) error { // no command if args.NoCommand { cleanupAfterLogin() - _ = client.Wait() + _ = ss.client.Wait() return nil } @@ -296,24 +233,24 @@ func sshStart(args *sshArgs) error { } // execute remote tools if necessary - execRemoteTools(args, client) + execRemoteTools(args, ss.client) // run command or start shell - if command != "" { - if err := session.Start(command); err != nil { - return fmt.Errorf("start command [%s] failed: %v", command, err) + if ss.cmd != "" { + if err := ss.session.Start(ss.cmd); err != nil { + return fmt.Errorf("start command [%s] failed: %v", ss.cmd, err) } } else { - if err := session.Shell(); err != nil { + if err := ss.session.Shell(); err != nil { return fmt.Errorf("start shell failed: %v", err) } } // execute expect interactions if necessary - serverOut, serverErr = execExpectInteractions(args, serverIn, serverOut, serverErr) + execExpectInteractions(args, ss) // make stdin raw - if isTerminal && tty { + if isTerminal && ss.tty { state, err := makeStdinRaw() if err != nil { return err @@ -322,15 +259,15 @@ func sshStart(args *sshArgs) error { } // enable trzsz - if err := enableTrzsz(args, client, session, serverIn, serverOut, serverErr, tty); err != nil { + if err := enableTrzsz(args, ss); err != nil { return err } // cleanup and wait for exit cleanupAfterLogin() - _ = session.Wait() + _ = ss.session.Wait() if args.Background { - _ = client.Wait() + _ = ss.client.Wait() } return nil } diff --git a/tssh/tokens.go b/tssh/tokens.go index de7b050..a837154 100644 --- a/tssh/tokens.go +++ b/tssh/tokens.go @@ -75,7 +75,10 @@ var getHostname = func() string { return hostname } -func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) (string, error) { +func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (string, error) { + if !strings.ContainsRune(str, '%') { + return str, nil + } var buf strings.Builder state := byte(0) for _, c := range str { diff --git a/tssh/tokens_test.go b/tssh/tokens_test.go index 274bb68..451b77e 100644 --- a/tssh/tokens_test.go +++ b/tssh/tokens_test.go @@ -43,7 +43,7 @@ func TestExpandTokens(t *testing.T) { args := &sshArgs{ Destination: "dest", } - param := &loginParam{ + param := &sshParam{ host: "127.0.0.1", port: "1337", user: "penny", @@ -100,7 +100,7 @@ func TestInvalidHost(t *testing.T) { assertInvalidHost := func(host string) { t.Helper() - _, err := expandTokens("%h", &sshArgs{}, &loginParam{host: host}, "%hnpr") + _, err := expandTokens("%h", &sshArgs{}, &sshParam{host: host}, "%hnpr") require.NotNil(err) assert.Equal("hostname contains invalid characters", err.Error()) } @@ -139,7 +139,7 @@ func TestInvalidUser(t *testing.T) { assertInvalidUser := func(user string) { t.Helper() - _, err := expandTokens("%r", &sshArgs{}, &loginParam{user: user}, "%hnpr") + _, err := expandTokens("%r", &sshArgs{}, &sshParam{user: user}, "%hnpr") require.NotNil(err) assert.Equal("remote username contains invalid characters", err.Error()) } diff --git a/tssh/trzsz.go b/tssh/trzsz.go index f684c2d..5266e52 100644 --- a/tssh/trzsz.go +++ b/tssh/trzsz.go @@ -35,7 +35,6 @@ import ( "time" "github.com/trzsz/trzsz-go/trzsz" - "golang.org/x/crypto/ssh" ) func writeAll(dst io.Writer, data []byte) error { @@ -96,37 +95,36 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader } } -func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, - serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, tty bool) error { +func enableTrzsz(args *sshArgs, ss *sshSession) error { // not terminal or not tty - if !isTerminal || !tty { - wrapStdIO(serverIn, serverOut, serverErr, tty) + if !isTerminal || !ss.tty { + wrapStdIO(ss.serverIn, ss.serverOut, ss.serverErr, ss.tty) return nil } // disable trzsz ( trz / tsz ) if strings.ToLower(getExOptionConfig(args, "EnableTrzsz")) == "no" { - wrapStdIO(serverIn, serverOut, serverErr, tty) - onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) }) + wrapStdIO(ss.serverIn, ss.serverOut, ss.serverErr, ss.tty) + onTerminalResize(func(width, height int) { _ = ss.session.WindowChange(height, width) }) return nil } // support trzsz ( trz / tsz ) - wrapStdIO(nil, nil, serverErr, tty) + wrapStdIO(nil, nil, ss.serverErr, ss.tty) trzsz.SetAffectedByWindows(false) if args.Relay || isNoGUI() { // run as a relay - trzszRelay := trzsz.NewTrzszRelay(os.Stdin, os.Stdout, serverIn, serverOut, trzsz.TrzszOptions{ + trzszRelay := trzsz.NewTrzszRelay(os.Stdin, os.Stdout, ss.serverIn, ss.serverOut, trzsz.TrzszOptions{ DetectTraceLog: args.TraceLog, }) // reset terminal size on resize - onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) }) + onTerminalResize(func(width, height int) { _ = ss.session.WindowChange(height, width) }) // setup tunnel connect trzszRelay.SetTunnelConnector(func(port int) net.Conn { - conn, _ := dialWithTimeout(client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second) + conn, _ := dialWithTimeout(ss.client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second) return conn }) return nil @@ -146,7 +144,7 @@ func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, // os.Stdout │ │ os.Stdout └─────────────┘ ServerOut │ │ // ◄───────────│ │◄──────────────────────────────────────────┤ │ // os.Stderr └────────┘ stderr └────────┘ - trzszFilter := trzsz.NewTrzszFilter(os.Stdin, os.Stdout, serverIn, serverOut, trzsz.TrzszOptions{ + trzszFilter := trzsz.NewTrzszFilter(os.Stdin, os.Stdout, ss.serverIn, ss.serverOut, trzsz.TrzszOptions{ TerminalColumns: int32(width), DetectDragFile: args.DragFile || strings.ToLower(getExOptionConfig(args, "EnableDragFile")) == "yes", DetectTraceLog: args.TraceLog, @@ -156,7 +154,7 @@ func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, // reset terminal size on resize onTerminalResize(func(width, height int) { trzszFilter.SetTerminalColumns(int32(width)) - _ = session.WindowChange(height, width) + _ = ss.session.WindowChange(height, width) }) // setup default paths @@ -165,7 +163,7 @@ func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, // setup tunnel connect trzszFilter.SetTunnelConnector(func(port int) net.Conn { - conn, _ := dialWithTimeout(client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second) + conn, _ := dialWithTimeout(ss.client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second) return conn })