Skip to content

Commit

Permalink
enable trzsz after shell
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Oct 28, 2023
1 parent ea32c9d commit cabadb3
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 141 deletions.
122 changes: 8 additions & 114 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ package tssh

import (
"bufio"
"bytes"
"crypto/x509"
"encoding/hex"
"fmt"
Expand All @@ -37,14 +36,12 @@ import (
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/skeema/knownhosts"
"github.com/trzsz/trzsz-go/trzsz"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/term"
Expand Down Expand Up @@ -729,6 +726,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e
if err != nil {
return nil, fmt.Errorf("proxy [%s] new conn [%s] failed: %v", proxy, param.addr, err)
}
debug("login to [%s] success", args.Destination)
return ssh.NewClient(ncc, chans, reqs), nil
}

Expand All @@ -748,6 +746,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e
if err != nil {
return nil, fmt.Errorf("proxy command [%s] new conn [%s] failed: %v", cmd, param.addr, err)
}
debug("login to [%s] success", args.Destination)
return ssh.NewClient(ncc, chans, reqs), nil
}

Expand All @@ -762,6 +761,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e
if err != nil {
return nil, fmt.Errorf("new conn [%s] failed: %v", param.addr, err)
}
debug("login to [%s] success", args.Destination)
return ssh.NewClient(ncc, chans, reqs), nil
}

Expand Down Expand Up @@ -812,45 +812,6 @@ func keepAlive(client *ssh.Client, args *sshArgs) {
}()
}

func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
win := runtime.GOOS == "windows"
forwardIO := func(reader io.Reader, writer io.WriteCloser, oldVal, newVal []byte) {
defer writer.Close()
buffer := make([]byte, 32*1024)
for {
n, err := reader.Read(buffer)
if n > 0 {
buf := buffer[:n]
if win && !tty {
buf = bytes.ReplaceAll(buf, oldVal, newVal)
}
w := 0
for w < len(buf) {
n, err := writer.Write(buf[w:])
if err != nil {
warning("wrap stdio write failed: %v", err)
return
}
w += n
}
}
if err == io.EOF {
if win && tty {
_, _ = writer.Write([]byte{0x1A}) // ctrl + z
continue
}
break
}
if err != nil {
warning("wrap stdio read failed: %v", err)
return
}
}
}
go forwardIO(os.Stdin, serverIn, []byte("\r\n"), []byte("\n"))
go forwardIO(serverOut, os.Stdout, []byte("\n"), []byte("\r\n"))
}

func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) {
agentClient := getAgentClient()
if agentClient == nil {
Expand All @@ -871,7 +832,7 @@ func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) {
debug("request ssh agent forwarding success")
}

func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, err error) {
func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, serverIn io.WriteCloser, serverOut io.Reader, err error) {
defer func() {
if err != nil {
if session != nil {
Expand Down Expand Up @@ -911,12 +872,12 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session
session.Stderr = os.Stderr

// session input and output
serverIn, err := session.StdinPipe()
serverIn, err = session.StdinPipe()
if err != nil {
err = fmt.Errorf("stdin pipe failed: %v", err)
return
}
serverOut, err := session.StdoutPipe()
serverOut, err = session.StdoutPipe()
if err != nil {
err = fmt.Errorf("stdout pipe failed: %v", err)
return
Expand All @@ -925,9 +886,8 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session
// ssh agent forward
sshAgentForward(args, client, session)

// no tty
if !tty {
wrapStdIO(serverIn, serverOut, tty)
// not terminal or not tty
if !isTerminal || !tty {
return
}

Expand All @@ -942,71 +902,5 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session
return
}

// make stdin raw
if isTerminal {
var state *stdinState
state, err = makeStdinRaw()
if err != nil {
return
}
onExitFuncs = append(onExitFuncs, func() {
resetStdin(state)
})
}

// disable trzsz ( trz / tsz )
if strings.ToLower(getExOptionConfig(args, "EnableTrzsz")) == "no" {
wrapStdIO(serverIn, serverOut, tty)
onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) })
return
}

// support trzsz ( trz / tsz )
trzsz.SetAffectedByWindows(false)
if args.Relay || isNoGUI() {
// run as a relay
trzszRelay := trzsz.NewTrzszRelay(os.Stdin, os.Stdout, serverIn, serverOut, trzsz.TrzszOptions{
DetectTraceLog: args.TraceLog,
})
// reset terminal size on resize
onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) })
// setup tunnel connect
trzszRelay.SetTunnelConnector(func(port int) net.Conn {
conn, _ := dialWithTimeout(client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second)
return conn
})
} else {
// create a TrzszFilter to support trzsz ( trz / tsz )
//
// os.Stdin ┌────────┐ os.Stdin ┌─────────────┐ ServerIn ┌────────┐
// ───────────►│ ├─────────────►│ ├─────────────►│ │
// │ │ │ TrzszFilter │ │ │
// ◄───────────│ Client │◄─────────────┤ │◄─────────────┤ Server │
// os.Stdout │ │ os.Stdout └─────────────┘ ServerOut │ │
// ◄───────────│ │◄──────────────────────────────────────────┤ │
// os.Stderr └────────┘ stderr └────────┘
trzszFilter := trzsz.NewTrzszFilter(os.Stdin, os.Stdout, serverIn, serverOut, trzsz.TrzszOptions{
TerminalColumns: int32(width),
DetectDragFile: args.DragFile || strings.ToLower(getExOptionConfig(args, "EnableDragFile")) == "yes",
DetectTraceLog: args.TraceLog,
})

// reset terminal size on resize
onTerminalResize(func(width, height int) {
trzszFilter.SetTerminalColumns(int32(width))
_ = session.WindowChange(height, width)
})

// setup default paths
trzszFilter.SetDefaultUploadPath(userConfig.defaultUploadPath)
trzszFilter.SetDefaultDownloadPath(userConfig.defaultDownloadPath)

// setup tunnel connect
trzszFilter.SetTunnelConnector(func(port int) net.Conn {
conn, _ := dialWithTimeout(client, "tcp", fmt.Sprintf("127.0.0.1:%d", port), time.Second)
return conn
})
}

return
}
69 changes: 42 additions & 27 deletions tssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,24 @@ func TsshMain() int {
}
args.Destination = dest

// start ssh program
if err = sshStart(&args); err != nil {
return 6
}
return 0
}

func sshStart(args *sshArgs) error {
// parse cmd and tty
command, tty, err := parseCmdAndTTY(&args)
command, tty, err := parseCmdAndTTY(args)
if err != nil {
return 6
return err
}

// ssh login
client, session, err := sshLogin(&args, tty)
client, session, serverIn, serverOut, err := sshLogin(args, tty)
if err != nil {
return 7
return err
}
defer client.Close()
if session != nil {
Expand All @@ -265,48 +273,55 @@ func TsshMain() int {
var wg *sync.WaitGroup
wg, err = stdioForward(client, args.StdioForward)
if err != nil {
return 8
return err
}
cleanupForGC()
wg.Wait()
return 0
return nil
}

// ssh forward
if err = sshForward(client, &args); err != nil {
return 9
if err := sshForward(client, args); err != nil {
return err
}

// cleanup for GC
cleanupForGC()

// no command
if args.NoCommand {
if client.Wait() != nil {
return 10
}
return 0
cleanupForGC()
_ = client.Wait()
return nil
}

// run command or start shell
if command != "" {
if err = session.Start(command); err != nil {
err = fmt.Errorf("start command [%s] failed: %v", command, err)
return 11
if err := session.Start(command); err != nil {
return fmt.Errorf("start command [%s] failed: %v", command, err)
}
} else {
if err = session.Shell(); err != nil {
err = fmt.Errorf("start shell failed: %v", err)
return 12
if err := session.Shell(); err != nil {
return fmt.Errorf("start shell failed: %v", err)
}
}

// make stdin raw
if isTerminal && tty {
state, err := makeStdinRaw()
if err != nil {
return err
}
defer resetStdin(state)
}

// wait for exit
if session.Wait() != nil {
return 13
// enable trzsz
if err := enableTrzsz(args, client, session, serverIn, serverOut, tty); err != nil {
return err
}
if args.Background && client.Wait() != nil {
return 14

// cleanup and wait for exit
cleanupForGC()
_ = session.Wait()
if args.Background {
_ = client.Wait()
}
return 0
return nil
}
Loading

0 comments on commit cabadb3

Please sign in to comment.