-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathsshcmd.go
349 lines (307 loc) · 10.1 KB
/
sshcmd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package sshcmd
// The SSHCmd plugin implements an SSH command executor step. Only PublicKey
// and Password authentication are supported. GSSAPI is not supported yet.
//
// Warning: this plugin does not lock passwords and keys in memory, and does not
// securely erase them from memory to protect against forensic attacks. If you
// need that, please submit a PR.
//
// Warning: commands are interpreted, so be careful with external input in the
// test step arguments.
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"regexp"
	"strconv"
	"sync"
	"time"

	"github.com/facebookincubator/go-belt/tool/logger"
	"github.com/kballard/go-shellquote"
	"golang.org/x/crypto/ssh"

	"github.com/linuxboot/contest/pkg/event"
	"github.com/linuxboot/contest/pkg/event/testevent"
	"github.com/linuxboot/contest/pkg/logging"
	"github.com/linuxboot/contest/pkg/target"
	"github.com/linuxboot/contest/pkg/test"
	"github.com/linuxboot/contest/plugins/teststeps"
)
// Name is the identifier under which this plugin is registered and looked up.
var Name = "SSHCmd"

// Events lists the events this plugin may emit. The framework rejects any
// emitted event that is not registered here, causing the plugin to fail.
var Events = []event.Name{}

// Defaults applied when the corresponding step parameter is not provided.
const (
	defaultSSHPort          = 22
	defaultTimeoutParameter = "10m"
)
// SSHCmd is a test step that runs an arbitrary command on a remote host over
// SSH. Its fields hold the step parameters as read by validateAndPopulate.
type SSHCmd struct {
	// Host and Port identify the SSH server ("host" and "port" parameters).
	Host, Port *test.Param
	// User, PrivateKeyFile and Password are the login name and the
	// credentials ("user", "private_key_file" and "password" parameters).
	// The key file and the password are each only used when non-empty.
	User, PrivateKeyFile, Password *test.Param
	// Executable and Args make up the remote command line ("executable" and
	// "args" parameters).
	Executable *test.Param
	Args       []test.Param
	// Expect is a regular expression searched for in the command's standard
	// output ("expect" parameter).
	Expect *test.Param
	// Timeout is a duration string as accepted by time.ParseDuration
	// ("timeout" parameter).
	Timeout *test.Param
	// SkipIfEmptyHost is a boolean string; when true, a target whose expanded
	// host is empty is skipped instead of failing ("skip_if_empty_host"
	// parameter).
	SkipIfEmptyHost *test.Param
}
// Name reports the name this plugin is registered under.
func (SSHCmd) Name() string {
	return Name
}
// Run executes the configured command over SSH on every target received on
// ch. For each target it expands the step parameters, connects to the host,
// starts the command and then polls until one of the following happens:
//
//   - the command terminates: its error (if any) is returned; on success the
//     "expect" regular expression, when set, must match the captured stdout;
//   - the "expect" regular expression matches the stdout captured so far:
//     the target succeeds without waiting for the command to terminate;
//   - the timeout elapses: the target fails;
//   - ctx is cancelled: SIGKILL is sent to the remote process and the result
//     of sending that signal is returned.
//
// An error returned by validateAndPopulate is returned directly; per-target
// errors are returned from the callback handed to teststeps.ForEachTarget.
func (ts *SSHCmd) Run(
	ctx context.Context,
	ch test.TestStepChannels,
	ev testevent.Emitter,
	stepsVars test.StepsVariables,
	params test.TestStepParameters,
	resumeState json.RawMessage,
) (json.RawMessage, error) {
	log := logger.FromCtx(ctx)

	// XXX: Dragons ahead! The target (%t) substitution, and function
	// expression evaluations are done at run-time, so they may still fail
	// despite passing at early validation time.
	// If the function evaluations called in validateAndPopulate are not idempotent,
	// the output of the function expressions may be different (e.g. with a call to a
	// backend or a random pool of results)
	// Function evaluation could be done at validation time, but target
	// substitution cannot, because the targets are not known at that time.
	if err := ts.validateAndPopulate(params); err != nil {
		return nil, err
	}

	f := func(ctx context.Context, tgt *target.Target) error {
		// apply filters and substitutions to user, host, private key, and command args
		user, err := ts.User.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand user parameter: %v", err)
		}

		host, err := ts.Host.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand host parameter: %v", err)
		}

		if len(host) == 0 {
			shouldSkip := false
			if !ts.SkipIfEmptyHost.IsEmpty() {
				shouldSkip, err = strconv.ParseBool(ts.SkipIfEmptyHost.String())
				if err != nil {
					return fmt.Errorf("cannot expand 'skip_if_empty_host' parameter value '%s': %w", ts.SkipIfEmptyHost, err)
				}
			}
			if shouldSkip {
				return nil
			}
			return fmt.Errorf("host value is empty")
		}

		portStr, err := ts.Port.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand port parameter: %v", err)
		}
		port, err := strconv.Atoi(portStr)
		if err != nil {
			return fmt.Errorf("failed to convert port parameter to integer: %v", err)
		}

		timeoutStr, err := ts.Timeout.Expand(tgt, stepsVars)
		if err != nil {
			// timeoutStr carries no useful value on this path, so report the
			// unexpanded parameter instead.
			return fmt.Errorf("cannot expand timeout parameter %s: %v", ts.Timeout, err)
		}
		timeout, err := time.ParseDuration(timeoutStr)
		if err != nil {
			return fmt.Errorf("cannot parse timeout parameter: %v", err)
		}
		// The deadline is taken before connecting, so connection setup counts
		// against the timeout as well.
		deadline := time.Now().Add(timeout)

		// apply functions to the private key, if any
		var signer ssh.Signer
		privKeyFile, err := ts.PrivateKeyFile.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand private key file parameter: %v", err)
		}
		if privKeyFile != "" {
			key, err := os.ReadFile(privKeyFile)
			if err != nil {
				// Report the expanded path, i.e. the file that was actually read.
				return fmt.Errorf("cannot read private key at %s: %v", privKeyFile, err)
			}
			signer, err = ssh.ParsePrivateKey(key)
			if err != nil {
				return fmt.Errorf("cannot parse private key: %v", err)
			}
		}

		password, err := ts.Password.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand password parameter: %v", err)
		}

		// Offer only the authentication methods for which credentials exist.
		auth := []ssh.AuthMethod{}
		if signer != nil {
			auth = append(auth, ssh.PublicKeys(signer))
		}
		if password != "" {
			auth = append(auth, ssh.Password(password))
		}

		config := ssh.ClientConfig{
			User: user,
			Auth: auth,
			// TODO expose this in the plugin arguments
			//HostKeyCallback: ssh.FixedHostKey(hostKey),
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		}

		executable, err := ts.Executable.Expand(tgt, stepsVars)
		if err != nil {
			return fmt.Errorf("cannot expand executable parameter: %v", err)
		}

		// apply functions to the command args, if any
		var args []string
		for _, arg := range ts.Args {
			earg, err := arg.Expand(tgt, stepsVars)
			if err != nil {
				return fmt.Errorf("cannot expand command argument '%s': %v", arg, err)
			}
			args = append(args, earg)
		}

		// Compile the expectation before touching the network, so that a
		// malformed pattern is reported without running anything remotely.
		expect := ts.Expect.String()
		re, err := regexp.Compile(expect)
		if err != nil {
			return fmt.Errorf("malformed expect parameter: Can not compile %s with %v", expect, err)
		}

		// connect to the host
		addr := net.JoinHostPort(host, strconv.Itoa(port))
		client, err := ssh.Dial("tcp", addr, &config)
		if err != nil {
			return fmt.Errorf("cannot connect to SSH server %s: %v", addr, err)
		}
		defer func() {
			if err := client.Close(); err != nil {
				logging.Warnf(ctx, "Failed to close SSH connection to %s: %v", addr, err)
			}
		}()

		session, err := client.NewSession()
		if err != nil {
			return fmt.Errorf("cannot create SSH session to server %s: %v", addr, err)
		}
		defer func() {
			if err := session.Close(); err != nil && err != io.EOF {
				logging.Warnf(ctx, "Failed to close SSH session to %s: %v", addr, err)
			}
		}()

		// run the remote command and catch stdout/stderr
		// stdout is polled below while the command is still running, hence the
		// lock-protected buffer; stderr is only read once the command is done.
		var stdout lockedBuffer
		var stderr bytes.Buffer
		session.Stdout, session.Stderr = &stdout, &stderr
		cmd := shellquote.Join(append([]string{executable}, args...)...)
		log.Debugf("Running remote SSH command on %s: '%v'", addr, cmd)

		// Buffered so that the goroutine can finish even if nobody is left to
		// receive its result.
		errCh := make(chan error, 1)
		go func() {
			errCh <- session.Run(cmd)
		}()

		ticker := time.NewTicker(250 * time.Millisecond)
		defer ticker.Stop()

		keepAliveCnt := 0
		for {
			select {
			case err := <-errCh:
				log.Infof("Stdout of command '%s' is '%s'", cmd, stdout.Bytes())
				if err == nil {
					// Execute expectations
					if expect == "" {
						logging.Warnf(ctx, "no expectations specified")
					} else if stdout.Matches(re) {
						log.Infof("match for regex '%s' found", expect)
					} else {
						return fmt.Errorf("match for %s not found for target %v", expect, tgt)
					}
				} else {
					logging.Warnf(ctx, "Stderr of command '%s' is '%s'", cmd, stderr.Bytes())
				}
				return err

			case <-ctx.Done():
				return session.Signal(ssh.SIGKILL)

			case <-ticker.C:
				keepAliveCnt++
				if expect != "" && stdout.Matches(re) {
					log.Infof("match for regex '%s' found", expect)
					return nil
				}
				if time.Now().After(deadline) {
					return fmt.Errorf("timed out after %s", timeout)
				}
				// This is needed to keep the connection to the server alive
				if keepAliveCnt%20 == 0 {
					if err := session.Signal(ssh.Signal("CONT")); err != nil {
						log.Warnf("Unable to send CONT to ssh server: %v", err)
					}
				}
			}
		}
	}
	return teststeps.ForEachTarget(Name, ctx, ch, f)
}

// lockedBuffer is an in-memory io.Writer guarded by a mutex, so that its
// content can be inspected while another goroutine is still writing to it.
type lockedBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer.
func (b *lockedBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.buf.Write(p)
}

// Bytes returns a copy of everything written so far.
func (b *lockedBuffer) Bytes() []byte {
	b.mu.Lock()
	defer b.mu.Unlock()
	return append([]byte(nil), b.buf.Bytes()...)
}

// Matches reports whether re matches anywhere in the data written so far.
func (b *lockedBuffer) Matches(re *regexp.Regexp) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	return re.Match(b.buf.Bytes())
}
// validateAndPopulate reads the step parameters from params, checks the ones
// that can be checked without a target, and stores them on ts for use by Run.
//
// "host", "user" and "executable" are mandatory. "port" defaults to
// defaultSSHPort and, when given, must be an integer in the range 0-65535.
// "timeout" defaults to defaultTimeoutParameter. "private_key_file",
// "password", "args", "expect" and "skip_if_empty_host" are stored as given
// and may be empty.
//
// It returns an error describing the first invalid or missing parameter.
func (ts *SSHCmd) validateAndPopulate(params test.TestStepParameters) error {
	ts.Host = params.GetOne("host")
	if ts.Host.IsEmpty() {
		return errors.New("invalid or missing 'host' parameter, must be exactly one string")
	}

	if params.GetOne("port").IsEmpty() {
		ts.Port = test.NewParam(strconv.Itoa(defaultSSHPort))
	} else {
		port, err := params.GetInt("port")
		if err != nil {
			return fmt.Errorf("invalid 'port' parameter, not an integer: %v", err)
		}
		if port < 0 || port > 0xffff {
			return fmt.Errorf("invalid 'port' parameter: not in range 0-65535")
		}
		// Store the validated port; otherwise a user-supplied port would be
		// checked here but never reach Run.
		ts.Port = test.NewParam(strconv.FormatInt(port, 10))
	}

	ts.User = params.GetOne("user")
	if ts.User.IsEmpty() {
		return errors.New("invalid or missing 'user' parameter, must be exactly one string")
	}

	// do not fail if key file is empty, in such case it won't be used
	ts.PrivateKeyFile = params.GetOne("private_key_file")

	// do not fail if password is empty, in such case it won't be used
	ts.Password = params.GetOne("password")

	ts.Executable = params.GetOne("executable")
	if ts.Executable.IsEmpty() {
		return errors.New("invalid or missing 'executable' parameter, must be exactly one string")
	}

	ts.Args = params.Get("args")
	ts.Expect = params.GetOne("expect")

	if params.GetOne("timeout").IsEmpty() {
		ts.Timeout = test.NewParam(defaultTimeoutParameter)
	} else {
		ts.Timeout = params.GetOne("timeout")
	}

	ts.SkipIfEmptyHost = params.GetOne("skip_if_empty_host")
	return nil
}
// ValidateParameters logs the received parameters at debug level and checks
// them by delegating to validateAndPopulate, whose error (if any) is returned
// unchanged.
func (ts *SSHCmd) ValidateParameters(ctx context.Context, params test.TestStepParameters) error {
	logging.Debugf(ctx, "Params %+v", params)
	if err := ts.validateAndPopulate(params); err != nil {
		return err
	}
	return nil
}
// New is the factory for this plugin: it returns a fresh, zero-valued SSHCmd
// test step.
func New() test.TestStep {
	return new(SSHCmd)
}
// Load provides everything needed to register this step: the plugin name, the
// factory that creates new instances, and the events the plugin may emit.
func Load() (string, test.TestStepFactory, []event.Name) {
	var factory test.TestStepFactory = New
	return Name, factory, Events
}