Skip to content

Commit

Permalink
expand supported options
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Jan 27, 2024
1 parent f016236 commit 6c684a0
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 158 deletions.
24 changes: 16 additions & 8 deletions tssh/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,33 @@ var (
agentClient agent.ExtendedAgent
)

func getAgentAddr(args *sshArgs) string {
func getAgentAddr(args *sshArgs, param *sshParam) (string, error) {
if addr := getOptionConfig(args, "IdentityAgent"); addr != "" {
if strings.ToLower(addr) == "none" {
return ""
return "", nil
}
return addr
expandedAddr, err := expandTokens(addr, args, param, "%CdhikLlnpru")
if err != nil {
return "", fmt.Errorf("expand IdentityAgent [%s] failed: %v", addr, err)
}
return resolveHomeDir(expandedAddr), nil
}
if addr := os.Getenv("SSH_AUTH_SOCK"); addr != "" {
return addr
return resolveHomeDir(addr), nil
}
if addr := defaultAgentAddr; addr != "" && isFileExist(addr) {
return addr
return addr, nil
}
return ""
return "", nil
}

func getAgentClient(args *sshArgs) agent.ExtendedAgent {
func getAgentClient(args *sshArgs, param *sshParam) agent.ExtendedAgent {
agentOnce.Do(func() {
addr := resolveHomeDir(getAgentAddr(args))
addr, err := getAgentAddr(args, param)
if err != nil {
warning("get agent addr failed: %v", err)
return
}
if addr == "" {
debug("ssh agent address is not set")
return
Expand Down
4 changes: 2 additions & 2 deletions tssh/ctrl_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func startControlMaster(args *sshArgs) error {
return nil
}

func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client {
func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client {
ctrlMaster := getOptionConfig(args, "ControlMaster")
ctrlPath := getOptionConfig(args, "ControlPath")

Expand All @@ -322,7 +322,7 @@ func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client {

socket, err := expandTokens(ctrlPath, args, param, "%CdhikLlnpru")
if err != nil {
warning("expand control socket [%s] failed: %v", socket, err)
warning("expand ControlPath [%s] failed: %v", socket, err)
return nil
}
socket = resolveHomeDir(socket)
Expand Down
2 changes: 1 addition & 1 deletion tssh/ctrl_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"golang.org/x/crypto/ssh"
)

func connectViaControl(args *sshArgs, param *loginParam) *ssh.Client {
func connectViaControl(args *sshArgs, param *sshParam) *ssh.Client {
ctrlMaster := getOptionConfig(args, "ControlMaster")
ctrlPath := getOptionConfig(args, "ControlPath")

Expand Down
16 changes: 8 additions & 8 deletions tssh/expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,10 @@ func getExpectTimeout(args *sshArgs, prefix string) int {
return int(count)
}

func execExpectInteractions(args *sshArgs, serverIn io.Writer,
serverOut io.Reader, serverErr io.Reader) (io.Reader, io.Reader) {
func execExpectInteractions(args *sshArgs, ss *sshSession) {
expectCount := getExpectCount(args, "")
if expectCount <= 0 {
return serverOut, serverErr
return
}

outReader, outWriter := io.Pipe()
Expand All @@ -456,15 +455,16 @@ func execExpectInteractions(args *sshArgs, serverIn io.Writer,
out: make(chan []byte, 10),
err: make(chan []byte, 10),
}
go expect.wrapOutput(serverOut, outWriter, expect.out)
go expect.wrapOutput(serverErr, errWriter, expect.err)
go expect.wrapOutput(ss.serverOut, outWriter, expect.out)
go expect.wrapOutput(ss.serverErr, errWriter, expect.err)

expect.execInteractions(serverIn, expectCount)
expect.execInteractions(ss.serverIn, expectCount)

if ctx.Err() == context.DeadlineExceeded {
warning("expect timeout after %d seconds", expectTimeout)
_, _ = serverIn.Write([]byte("\r")) // enter for shell prompt if timeout
_, _ = ss.serverIn.Write([]byte("\r")) // enter for shell prompt if timeout
}

return outReader, errReader
ss.serverOut = outReader
ss.serverErr = errReader
}
16 changes: 13 additions & 3 deletions tssh/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) {
}
}

func sshForward(client *ssh.Client, args *sshArgs) error {
func sshForward(client *ssh.Client, args *sshArgs, param *sshParam) error {
// clear all forwardings
if strings.ToLower(getOptionConfig(args, "ClearAllForwardings")) == "yes" {
return nil
Expand All @@ -406,7 +406,12 @@ func sshForward(client *ssh.Client, args *sshArgs) error {
localForward(client, f, args)
}
for _, s := range getAllOptionConfig(args, "LocalForward") {
f, err := parseForwardCfg(s)
es, err := expandTokens(s, args, param, "%CdhikLlnpru")
if err != nil {
warning("expand LocalForward [%s] failed: %v", s, err)
continue
}
f, err := parseForwardCfg(es)
if err != nil {
warning("local forward failed: %v", err)
continue
Expand All @@ -419,7 +424,12 @@ func sshForward(client *ssh.Client, args *sshArgs) error {
remoteForward(client, f, args)
}
for _, s := range getAllOptionConfig(args, "RemoteForward") {
f, err := parseForwardCfg(s)
es, err := expandTokens(s, args, param, "%CdhikLlnpru")
if err != nil {
warning("expand RemoteForward [%s] failed: %v", s, err)
continue
}
f, err := parseForwardCfg(es)
if err != nil {
warning("remote forward failed: %v", err)
continue
Expand Down
Loading

0 comments on commit 6c684a0

Please sign in to comment.