-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshell.go
167 lines (133 loc) · 3.92 KB
/
shell.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
package easshy
import (
"bytes"
"context"
"errors"
"io"
"golang.org/x/crypto/ssh"
)
var (
	// ErrUnsupportedPrompt is returned when the remote shell's initial prompt
	// is shorter than two bytes or does not end with a space.
	ErrUnsupportedPrompt = errors.New("unsupported prompt")

	// ErrBufferOverflow is returned when a command's output does not fit in
	// the shell's read buffer.
	ErrBufferOverflow = errors.New("buffer overflow")
)
// shell is a wrapper over ssh.Session that allows calling multiple commands and
// reading their output separately, while preserving the shell's context
// between commands.
type shell struct {
	*ssh.Session
	stdin        io.WriteCloser // Standard input pipe; commands are written here.
	stdout       io.Reader      // Standard output pipe; prompts and command output are read from here.
	promptSuffix string         // Last two bytes of the shell's initial prompt (set by start); marks the end of a command's output.
	readBuffer   []byte         // Buffer for command output; growBuffer refuses to grow it once its length reaches 1 MiB.
}
// newShell wraps session in a *shell. It requests a pseudo-terminal and
// attaches pipes to the session's standard output and standard input.
//
// If any of those steps fails, the session is closed and the step's error is
// returned joined with any error reported by Close. The returned shell has
// not been started yet (see start).
func newShell(session *ssh.Session) (*shell, error) {
	// abort closes the session and reports the original failure together
	// with whatever Close returned.
	abort := func(cause error) (*shell, error) {
		return nil, errors.Join(cause, session.Close())
	}

	// NOTE(review): ssh.Session.RequestPty takes (term, height, width, modes),
	// so this asks for 80 rows by 40 columns; confirm that is intended.
	if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
		return abort(err)
	}

	out, err := session.StdoutPipe()
	if err != nil {
		return abort(err)
	}

	in, err := session.StdinPipe()
	if err != nil {
		return abort(err)
	}

	return &shell{
		Session:    session,
		stdin:      in,
		stdout:     out,
		readBuffer: make([]byte, 32*1024), // initial size of the output buffer
	}, nil
}
// start launches /bin/sh on the session. It then remembers the last two bytes
// of the first prompt the shell prints (for example "$ "), which read later
// uses to tell where a command's output ends.
//
// Only a single Read of at most 256 bytes is made, so the initial prompt is
// expected to arrive in one piece. ErrUnsupportedPrompt is returned when that
// data is shorter than two bytes or does not end with a space.
//
// ctx is checked once, after the shell has started and before the Read.
// Cancellation does not interrupt a Read that is already waiting.
func (this *shell) start(ctx context.Context) error {
	if err := this.Session.Start("/bin/sh"); err != nil {
		return err
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	prompt := make([]byte, 256) // assumed large enough for the initial prompt
	n, err := this.stdout.Read(prompt)
	if err != nil {
		return err
	}
	prompt = prompt[:n]

	if len(prompt) < 2 || prompt[len(prompt)-1] != ' ' {
		return ErrUnsupportedPrompt
	}
	this.promptSuffix = string(prompt[len(prompt)-2:])
	return nil
}
// runOpts applies the functional options to the shell in the order given.
// It stops at the first option that fails and returns that option's error.
// It returns nil when every option succeeds or none were given.
func (this *shell) runOpts(ctx context.Context, opts ...Option) error {
	var err error
	for i := 0; i < len(opts) && err == nil; i++ {
		err = opts[i](ctx, this)
	}
	return err
}
// setContext applies shCtx to the current session. It is shorthand for
// building the WithShellContext option and running it immediately.
func (this *shell) setContext(ctx context.Context, shCtx ShellContext) error {
	return WithShellContext(shCtx)(ctx, this)
}
// write sends cmd to the shell's standard input followed by a newline.
// Newlines already inside cmd are sent unchanged. Only the error from the
// underlying Write is returned; the byte count is discarded.
func (this *shell) write(cmd string) error {
	line := make([]byte, 0, len(cmd)+1)
	line = append(line, cmd...)
	line = append(line, '\n')
	_, err := this.stdin.Write(line)
	return err
}
// read returns the output produced by the most recently written command.
//
// It keeps reading stdout into readBuffer until the accumulated data ends with
// promptSuffix. It then returns everything before the line holding the prompt:
// the newline ending the output is kept and the prompt line is dropped.
//
// When the buffer is full, more room is requested from growBuffer. Once
// growBuffer refuses, the data read so far is returned together with its
// error (ErrBufferOverflow). Any other read error is returned with an empty
// string.
//
// ctx is checked before every Read. A Read that is already blocked waiting
// for remote output is not interrupted by cancellation.
func (this *shell) read(ctx context.Context) (string, error) {
	suffix := []byte(this.promptSuffix)
	t := 0 // number of valid bytes at the start of readBuffer
	for {
		if t == len(this.readBuffer) {
			// Grow only when the buffer is actually full.
			if err := this.growBuffer(); err != nil {
				return string(this.readBuffer[:t]), err
			}
			// growBuffer may add capacity without changing the length;
			// expose that capacity so the next Read can fill it.
			this.readBuffer = this.readBuffer[:cap(this.readBuffer)]
			if t == len(this.readBuffer) {
				// No room was gained; stop rather than spin on empty reads.
				return string(this.readBuffer[:t]), ErrBufferOverflow
			}
		}

		if err := ctx.Err(); err != nil {
			return "", err
		}

		n, err := this.stdout.Read(this.readBuffer[t:])
		t += n

		// Bytes returned alongside an error are still valid (io.Reader
		// contract), so look for the prompt before handling err. HasSuffix
		// also copes with fewer accumulated bytes than the suffix length.
		if len(suffix) > 0 && bytes.HasSuffix(this.readBuffer[:t], suffix) {
			// Remove the whole prompt line, but keep the newline before it.
			end := bytes.LastIndexByte(this.readBuffer[:t-len(suffix)], '\n') + 1
			return string(this.readBuffer[:end]), nil
		}
		if err != nil {
			return "", err
		}
	}
}
// growBuffer makes room for more command output by doubling the capacity of
// readBuffer. Only the capacity changes: the length and the bytes already
// stored are kept as they are. Callers must re-slice the buffer up to its
// capacity to use the new room.
//
// It does nothing while spare capacity remains. It returns ErrBufferOverflow,
// without growing, once the buffer's length has reached 1 MiB.
func (this *shell) growBuffer() error {
	const limit = 1024 * 1024

	size, room := len(this.readBuffer), cap(this.readBuffer)
	switch {
	case size < room:
		return nil
	case size >= limit:
		return ErrBufferOverflow
	}

	grown := make([]byte, size, 2*room)
	copy(grown, this.readBuffer)
	this.readBuffer = grown
	return nil
}
// close sends `exit` to the remote shell, waits for the session to finish,
// and then closes the ssh.Session. The shell's session and pipes are released
// afterwards; calling close again is a no-op that returns nil.
//
// All failures (writing `exit`, Wait, Close) are joined into the returned
// error. Note that a bare `exit` makes sh exit with the status of the last
// command it ran, so Wait may report an *ssh.ExitError when that command
// failed.
func (this *shell) close() error {
	if this.Session == nil {
		// Already closed; avoid dereferencing the released pipes.
		return nil
	}

	var errs []error
	errs = append(errs, this.write("exit"))
	errs = append(errs, this.Session.Wait())

	// Release the channel, as the doc promises. Once the channel is already
	// closed (the normal state after Wait), x/crypto/ssh reports io.EOF from
	// Close, so that case is not treated as a failure.
	if err := this.Session.Close(); err != nil && !errors.Is(err, io.EOF) {
		errs = append(errs, err)
	}

	this.Session, this.stdin, this.stdout = nil, nil, nil
	return errors.Join(errs...)
}