From f6d0e7325110a7d48c0ac10af7ee12ae7500dd86 Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 13 Jul 2024 08:19:28 +0800 Subject: [PATCH] expose tssh as a library --- README.md | 38 +++++++++ cmd/tssh/main.go | 2 +- tssh/agent.go | 4 +- tssh/config.go | 6 +- tssh/ctrl_unix.go | 2 +- tssh/ctrl_windows.go | 2 +- tssh/env.go | 2 +- tssh/forward.go | 14 ++-- tssh/login.go | 10 +-- tssh/main.go | 14 +++- tssh/ssh.go | 148 ++++++++++++++++++++++++++++++++++-- tssh/tools.go | 6 +- tssh/tools_install_trzsz.go | 20 ++--- tssh/udp.go | 2 +- 14 files changed, 224 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 560a768..276a4f2 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,44 @@ trzsz-ssh ( tssh ) offers additional useful features: - Download from the [GitHub Releases](https://github.com/trzsz/trzsz-ssh/releases), unzip and add to `PATH` environment. +## Development + +The `github.com/trzsz/trzsz-ssh/tssh` can be used as a library, for example: + +```go +package main + +import ( + "log" + "os" + + "github.com/trzsz/trzsz-ssh/tssh" +) + +func main() { + // Example 1: execute command on remote server + client, err := tssh.SshLogin(&tssh.SshArgs{Destination: "root@192.168.0.1"}) + if err != nil { + log.Fatal(err) + } + defer client.Close() + session, err := client.NewSession() + if err != nil { + log.Fatal(err) + } + defer session.Close() + output, err := session.CombinedOutput("whoami") + if err != nil { + log.Fatal(err) + } + log.Printf("I'm %s", string(output)) + + // Example 2: run the tssh program + code := tssh.TsshMain([]string{"-t", "root@192.168.0.1", "bash -l"}) + os.Exit(code) +} +``` + ## Contributing Welcome and thank you for considering contributing. We appreciate all forms of support, from coding and testing to documentation and CI/CD improvements. diff --git a/cmd/tssh/main.go b/cmd/tssh/main.go index c11eecd..71ae1b2 100644 --- a/cmd/tssh/main.go +++ b/cmd/tssh/main.go @@ -31,5 +31,5 @@ import ( ) func main() { - os.Exit(tssh.TsshMain()) + os.Exit(tssh.TsshMain(os.Args[1:])) } diff --git a/tssh/agent.go b/tssh/agent.go index baed36c..d05fe97 100644 --- a/tssh/agent.go +++ b/tssh/agent.go @@ -88,7 +88,7 @@ func getAgentClient(args *sshArgs, param *sshParam) agent.ExtendedAgent { return agentClient } -func forwardToRemote(client sshClient, addr string) error { +func forwardToRemote(client SshClient, addr string) error { channels := client.HandleChannelOpen(kAgentChannelType) if channels == nil { return fmt.Errorf("agent: already have handler for %s", kAgentChannelType) @@ -122,7 +122,7 @@ func forwardAgentRequest(channel ssh.Channel, addr string) { forwardChannel(channel, conn) } -func requestAgentForwarding(session sshSession) error { +func requestAgentForwarding(session SshSession) error { ok, err := session.SendRequest(kAgentRequestName, true, nil) if err != nil { return err diff --git a/tssh/config.go b/tssh/config.go index b20b4a6..d3a293a 100644 --- a/tssh/config.go +++ b/tssh/config.go @@ -94,7 +94,7 @@ type tsshConfig struct { wildcardPatterns []*ssh_config.Pattern } -var userConfig = &tsshConfig{} +var userConfig *tsshConfig func parseTsshConfig() { path := filepath.Join(userHomeDir, ".tssh.conf") @@ -218,8 +218,8 @@ func showTsshConfig() { } } -func initUserConfig(configFile string) error { - var err error +func initUserConfig(configFile string) (err error) { + userConfig = &tsshConfig{} userHomeDir, err = os.UserHomeDir() if err != nil { debug("user home dir failed: %v", err) diff --git a/tssh/ctrl_unix.go b/tssh/ctrl_unix.go index 8309524..2f42e69 100644 --- a/tssh/ctrl_unix.go +++ b/tssh/ctrl_unix.go @@ -318,7 +318,7 @@ func startControlMaster(args *sshArgs, sshPath string) error { return nil } -func connectViaControl(args *sshArgs, param *sshParam) sshClient { +func connectViaControl(args *sshArgs, param *sshParam) SshClient { ctrlMaster := getOptionConfig(args, "ControlMaster") ctrlPath := getOptionConfig(args, "ControlPath") diff --git a/tssh/ctrl_windows.go b/tssh/ctrl_windows.go index 7171e90..7beb1b3 100644 --- a/tssh/ctrl_windows.go +++ b/tssh/ctrl_windows.go @@ -28,7 +28,7 @@ import ( "strings" ) -func connectViaControl(args *sshArgs, param *sshParam) sshClient { +func connectViaControl(args *sshArgs, param *sshParam) SshClient { ctrlMaster := getOptionConfig(args, "ControlMaster") ctrlPath := getOptionConfig(args, "ControlPath") diff --git a/tssh/env.go b/tssh/env.go index b07a4e2..8822f4f 100644 --- a/tssh/env.go +++ b/tssh/env.go @@ -118,7 +118,7 @@ func getSetEnvs(args *sshArgs) ([]*sshEnv, error) { return envs, nil } -func sendAndSetEnv(args *sshArgs, session sshSession) (string, error) { +func sendAndSetEnv(args *sshArgs, session SshSession) (string, error) { sendEnvs, err := getSendEnvs(args) if err != nil { return "", err diff --git a/tssh/forward.go b/tssh/forward.go index 3e74cfc..2f34728 100644 --- a/tssh/forward.go +++ b/tssh/forward.go @@ -227,7 +227,7 @@ func listenOnLocal(args *sshArgs, addr *string, port string) (listeners []net.Li return } -func listenOnRemote(args *sshArgs, client sshClient, addr *string, port string) (listeners []net.Listener) { +func listenOnRemote(args *sshArgs, client SshClient, addr *string, port string) (listeners []net.Listener) { listen := func(network, address string) { listener, err := client.Listen(network, address) if err != nil { @@ -251,7 +251,7 @@ func listenOnRemote(args *sshArgs, client sshClient, addr *string, port string) return } -func stdioForward(client sshClient, addr string) (*sync.WaitGroup, error) { +func stdioForward(client SshClient, addr string) (*sync.WaitGroup, error) { conn, err := client.DialTimeout("tcp", addr, 10*time.Second) if err != nil { return nil, fmt.Errorf("stdio forward failed: %v", err) @@ -283,7 +283,7 @@ func (d sshResolver) Resolve(ctx context.Context, name string) (context.Context, return ctx, []byte{}, nil } -func dynamicForward(client sshClient, b *bindCfg, args *sshArgs) { +func dynamicForward(client SshClient, b *bindCfg, args *sshArgs) { server, err := socks5.New(&socks5.Config{ Resolver: &sshResolver{}, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -334,7 +334,7 @@ func netForward(local, remote net.Conn) { <-done } -func localForward(client sshClient, f *forwardCfg, args *sshArgs) { +func localForward(client SshClient, f *forwardCfg, args *sshArgs) { remoteAddr := joinHostPort(f.destHost, strconv.Itoa(f.destPort)) for _, listener := range listenOnLocal(args, f.bindAddr, strconv.Itoa(f.bindPort)) { go func(listener net.Listener) { @@ -360,7 +360,7 @@ func localForward(client sshClient, f *forwardCfg, args *sshArgs) { } } -func remoteForward(client sshClient, f *forwardCfg, args *sshArgs) { +func remoteForward(client SshClient, f *forwardCfg, args *sshArgs) { localAddr := joinHostPort(f.destHost, strconv.Itoa(f.destPort)) for _, listener := range listenOnRemote(args, client, f.bindAddr, strconv.Itoa(f.bindPort)) { go func(listener net.Listener) { @@ -386,7 +386,7 @@ func remoteForward(client sshClient, f *forwardCfg, args *sshArgs) { } } -func sshForward(client sshClient, args *sshArgs, param *sshParam) error { +func sshForward(client SshClient, args *sshArgs, param *sshParam) error { // clear all forwardings if strings.ToLower(getOptionConfig(args, "ClearAllForwardings")) == "yes" { return nil @@ -451,7 +451,7 @@ type x11Request struct { ScreenNumber uint32 } -func sshX11Forward(args *sshArgs, client sshClient, session sshSession) { +func sshX11Forward(args *sshArgs, client SshClient, session SshSession) { if args.NoX11Forward || !args.X11Untrusted && !args.X11Trusted && strings.ToLower(getOptionConfig(args, "ForwardX11")) != "yes" { return } diff --git a/tssh/login.go b/tssh/login.go index 19f469b..c557f94 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -992,7 +992,7 @@ func getNetworkAddressFamily(args *sshArgs) string { } } -func sshConnect(args *sshArgs, client sshClient, proxy string) (sshClient, *sshParam, bool, error) { +func sshConnect(args *sshArgs, client SshClient, proxy string) (SshClient, *sshParam, bool, error) { param, err := getSshParam(args) if err != nil { return nil, nil, false, err @@ -1028,7 +1028,7 @@ func sshConnect(args *sshArgs, client sshClient, proxy string) (sshClient, *sshP network := getNetworkAddressFamily(args) - proxyConnect := func(client sshClient, proxy string) (sshClient, *sshParam, bool, error) { + proxyConnect := func(client SshClient, proxy string) (SshClient, *sshParam, bool, error) { debug("login to [%s], addr: %s", args.Destination, param.addr) conn, err := client.DialTimeout(network, param.addr, 10*time.Second) if err != nil { @@ -1078,7 +1078,7 @@ func sshConnect(args *sshArgs, client sshClient, proxy string) (sshClient, *sshP } // has proxies - var proxyClient sshClient + var proxyClient SshClient for _, proxy = range param.proxy { proxyClient, _, _, err = sshConnect(&sshArgs{Destination: proxy}, proxyClient, proxy) if err != nil { @@ -1088,7 +1088,7 @@ func sshConnect(args *sshArgs, client sshClient, proxy string) (sshClient, *sshP return proxyConnect(proxyClient, proxy) } -func keepAlive(client sshClient, args *sshArgs) { +func keepAlive(client SshClient, args *sshArgs) { getOptionValue := func(option string) int { value, err := strconv.Atoi(getOptionConfig(args, option)) if err != nil { @@ -1129,7 +1129,7 @@ func keepAlive(client sshClient, args *sshArgs) { }() } -func sshAgentForward(args *sshArgs, param *sshParam, client sshClient, session sshSession) { +func sshAgentForward(args *sshArgs, param *sshParam, client SshClient, session SshSession) { if args.NoForwardAgent || !args.ForwardAgent && strings.ToLower(getOptionConfig(args, "ForwardAgent")) != "yes" { return } diff --git a/tssh/main.go b/tssh/main.go index 68f8d9b..294b444 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -116,9 +116,16 @@ func cleanupAfterLogin() { var isTerminal bool = isatty.IsTerminal(os.Stdin.Fd()) || isatty.IsCygwinTerminal(os.Stdin.Fd()) -func TsshMain() int { +// TrzMain is the main function of tssh program. +func TsshMain(argv []string) int { + // parse ssh args var args sshArgs - parser := arg.MustParse(&args) + parser, err := arg.NewParser(arg.Config{Out: os.Stderr, Exit: os.Exit}, &args) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return -1 + } + parser.MustParse(argv) // debug log if args.Debug { @@ -129,7 +136,6 @@ func TsshMain() int { defer cleanupOnExit() // print message after stdin reset - var err error defer func() { if err != nil { fmt.Fprintf(os.Stderr, "%v\r\n", err) @@ -149,7 +155,7 @@ func TsshMain() int { } // execute local tools if necessary - if code, quit := execLocalTools(&args); quit { + if code, quit := execLocalTools(argv, &args); quit { return code } diff --git a/tssh/ssh.go b/tssh/ssh.go index 0131e90..b00be0c 100644 --- a/tssh/ssh.go +++ b/tssh/ssh.go @@ -28,6 +28,7 @@ import ( "fmt" "io" "net" + "strings" "time" "golang.org/x/crypto/ssh" @@ -40,36 +41,169 @@ const ( kAgentRequestName = "auth-agent-req@openssh.com" ) -type sshClient interface { +// SshClient implements a traditional SSH client that supports shells, +// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. +type SshClient interface { + + // Wait blocks until the connection has shut down. Wait() error + + // Close closes the underlying network connection. Close() error - NewSession() (sshSession, error) + + // NewSession opens a new Session for this client. + NewSession() (SshSession, error) + + // DialTimeout initiates a connection to the addr from the remote host. DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) + + // Listen requests the remote peer open a listening socket on addr. Listen(network, addr string) (net.Listener, error) + + // HandleChannelOpen returns a channel on which NewChannel requests + // for the given type are sent. If the type already is being handled, + // nil is returned. The channel is closed when the connection is closed. HandleChannelOpen(channelType string) <-chan ssh.NewChannel + + // SendRequest sends a global request, and returns the reply. SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) } -type sshSession interface { +// SshSession represents a connection to a remote command or shell. +type SshSession interface { + + // Wait waits for the remote command to exit. Wait() error + + // Close closes the underlying network connection. Close() error + + // Shell starts a login shell on the remote host. Shell() error + + // Run runs cmd on the remote host. Run(cmd string) error + + // Start runs cmd on the remote host. Start(cmd string) error + + // WindowChange informs the remote host about a terminal window dimension + // change to height rows and width columns. WindowChange(height, width int) error + + // Setenv sets an environment variable that will be applied to any + // command executed by Shell or Run. Setenv(name, value string) error + + // StdinPipe returns a pipe that will be connected to the + // remote command's standard input when the command starts. StdinPipe() (io.WriteCloser, error) + + // StdoutPipe returns a pipe that will be connected to the + // remote command's standard output when the command starts. StdoutPipe() (io.Reader, error) + + // StderrPipe returns a pipe that will be connected to the + // remote command's standard error when the command starts. StderrPipe() (io.Reader, error) + + // Output runs cmd on the remote host and returns its standard output. Output(cmd string) ([]byte, error) + + // CombinedOutput runs cmd on the remote host and returns its combined + // standard output and standard error. CombinedOutput(cmd string) ([]byte, error) + + // RequestPty requests the association of a pty with the session on the remote host. RequestPty(term string, height, width int, termmodes ssh.TerminalModes) error + + // SendRequest sends an out-of-band channel request on the SSH channel + // underlying the session. SendRequest(name string, wantReply bool, payload []byte) (bool, error) } +// SshArgs specifies the arguments to log in to the remote server. +type SshArgs struct { + + // Destination specifies the remote server to log in to. + // e.g., alias in ~/.ssh/config, [user@]hostname[:port]. + Destination string + + // IPv4Only forces ssh to use IPv4 addresses only + IPv4Only bool + + // IPv6Only forces ssh to use IPv6 addresses only + IPv6Only bool + + // Port to connect to on the remote host + Port int + + // LoginName specifies the user to log in as on the remote machine + LoginName string + + // Identity selects the identity (private key) for public key authentication + Identity []string + + // CipherSpec specifies the cipher for encrypting the session + CipherSpec string + + // ConfigFile specifies the per-user configuration file + ConfigFile string + + // ProxyJump specifies the jump hosts separated by comma characters + ProxyJump string + + // Option gives options in the format used in the configuration file + Option map[string][]string + + // Debug causes ssh to print debugging messages about its progress + Debug bool + + // Udp means using UDP protocol ( QUIC / KCP ) connection like mosh + Udp bool + + // TsshdPath specifies the tsshd absolute path on the server + TsshdPath string +} + +// SshLogin logs in to the remote server and creates a Client. +func SshLogin(args *SshArgs) (SshClient, error) { + options := make(map[string][]string) + for key, values := range args.Option { + name := strings.ToLower(key) + if _, ok := options[name]; ok { + return nil, fmt.Errorf("option %s is repeated", name) + } + options[name] = values + } + if err := initUserConfig(args.ConfigFile); err != nil { + return nil, err + } + ss, err := sshLogin(&sshArgs{ + NoCommand: true, + Destination: args.Destination, + IPv4Only: args.IPv4Only, + IPv6Only: args.IPv6Only, + Port: args.Port, + LoginName: args.LoginName, + Identity: multiStr{args.Identity}, + CipherSpec: args.CipherSpec, + ConfigFile: args.ConfigFile, + ProxyJump: args.ProxyJump, + Option: sshOption{options}, + Debug: args.Debug, + Udp: args.Udp, + TsshdPath: args.TsshdPath, + }) + if err != nil { + return nil, err + } + return ss.client, nil +} + type sshClientSession struct { - client sshClient - session sshSession + client SshClient + session SshSession serverIn io.WriteCloser serverOut io.Reader serverErr io.Reader @@ -101,7 +235,7 @@ func (c *sshClientWrapper) Close() error { return c.client.Close() } -func (c *sshClientWrapper) NewSession() (sshSession, error) { +func (c *sshClientWrapper) NewSession() (SshSession, error) { return c.client.NewSession() } @@ -132,7 +266,7 @@ func (c *sshClientWrapper) SendRequest(name string, wantReply bool, payload []by return c.client.SendRequest(name, wantReply, payload) } -func sshNewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) sshClient { +func sshNewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) SshClient { client := ssh.NewClient(c, chans, reqs) return &sshClientWrapper{client} } diff --git a/tssh/tools.go b/tssh/tools.go index 6bee9de..c18a62f 100644 --- a/tssh/tools.go +++ b/tssh/tools.go @@ -448,14 +448,14 @@ func isFileNotExistOrEmpty(path string) bool { // // return true to quit with return code // return false to continue ssh login -func execLocalTools(args *sshArgs) (int, bool) { +func execLocalTools(argv []string, args *sshArgs) (int, bool) { switch { case args.Ver: fmt.Println(args.Version()) return 0, true case args.EncSecret: return execEncodeSecret() - case args.NewHost || len(os.Args) == 1 && isFileNotExistOrEmpty(userConfig.configPath): + case args.NewHost || len(argv) == 0 && isFileNotExistOrEmpty(userConfig.configPath): return execNewHost(args) default: return 0, false @@ -463,7 +463,7 @@ func execLocalTools(args *sshArgs) (int, bool) { } // execRemoteTools execute remote tools if necessary -func execRemoteTools(args *sshArgs, client sshClient) { +func execRemoteTools(args *sshArgs, client SshClient) { switch { case args.InstallTrzsz: execInstallTrzsz(args, client) diff --git a/tssh/tools_install_trzsz.go b/tssh/tools_install_trzsz.go index 9ca0966..67f9dda 100644 --- a/tssh/tools_install_trzsz.go +++ b/tssh/tools_install_trzsz.go @@ -70,7 +70,7 @@ func getLatestTrzszVersion() (string, error) { return release.TagName[1:], nil } -func checkTrzszVersion(client sshClient, cmd, name, version string) bool { +func checkTrzszVersion(client SshClient, cmd, name, version string) bool { session, err := client.NewSession() if err != nil { return false @@ -91,16 +91,16 @@ func pathJoin(path, name string) string { return fmt.Sprintf("%s/%s", path, name) } -func checkInstalledVersion(client sshClient, path, name, version string) bool { +func checkInstalledVersion(client SshClient, path, name, version string) bool { cmd := fmt.Sprintf("%s -v", pathJoin(path, name)) return checkTrzszVersion(client, cmd, name, version) } -func checkTrzszExecutable(client sshClient, name, version string) bool { +func checkTrzszExecutable(client SshClient, name, version string) bool { return checkTrzszVersion(client, fmt.Sprintf("$SHELL -l -c '%s -v'", name), name, version) } -func checkTrzszPathEnv(client sshClient, version, path string) { +func checkTrzszPathEnv(client SshClient, version, path string) { trzExecutable := checkTrzszExecutable(client, "trz", version) tszExecutable := checkTrzszExecutable(client, "tsz", version) if !trzExecutable || !tszExecutable { @@ -108,7 +108,7 @@ func checkTrzszPathEnv(client sshClient, version, path string) { } } -func getRemoteUserHome(client sshClient) (string, error) { +func getRemoteUserHome(client SshClient) (string, error) { session, err := client.NewSession() if err != nil { return "", err @@ -135,7 +135,7 @@ func getRemoteUserHome(client sshClient) (string, error) { return "~", nil } -func getRemoteServerOS(client sshClient) (string, error) { +func getRemoteServerOS(client SshClient) (string, error) { session, err := client.NewSession() if err != nil { return "", err @@ -156,7 +156,7 @@ func getRemoteServerOS(client sshClient) (string, error) { } } -func getRemoteServerArch(client sshClient) (string, error) { +func getRemoteServerArch(client SshClient) (string, error) { session, err := client.NewSession() if err != nil { return "", err @@ -183,7 +183,7 @@ func getRemoteServerArch(client sshClient) (string, error) { } } -func mkdirInstallPath(client sshClient, path string) error { +func mkdirInstallPath(client SshClient, path string) error { session, err := client.NewSession() if err != nil { return err @@ -306,7 +306,7 @@ func readTrzszBinary(path, version, svrOS, arch string) ([]byte, []byte, error) return extractTrzszBinary(gzr, version, svrOS, arch) } -func uploadTrzszBinary(client sshClient, path string, trz, tsz []byte) error { +func uploadTrzszBinary(client SshClient, path string, trz, tsz []byte) error { session, err := client.NewSession() if err != nil { return err @@ -403,7 +403,7 @@ func uploadTrzszBinary(client sshClient, path string, trz, tsz []byte) error { return nil } -func execInstallTrzsz(args *sshArgs, client sshClient) { +func execInstallTrzsz(args *sshArgs, client SshClient) { version := args.TrzszVersion if version == "" { var err error diff --git a/tssh/udp.go b/tssh/udp.go index dfec9b6..d938fed 100644 --- a/tssh/udp.go +++ b/tssh/udp.go @@ -127,7 +127,7 @@ func (c *sshUdpClient) Close() error { return c.client.Close() } -func (c *sshUdpClient) NewSession() (sshSession, error) { +func (c *sshUdpClient) NewSession() (SshSession, error) { stream, err := c.newStream("session") if err != nil { return nil, err