-
Notifications
You must be signed in to change notification settings - Fork 12
/
keygen_test.go
359 lines (327 loc) · 9.6 KB
/
keygen_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
package keygen
import (
"bytes"
"crypto/elliptic"
"fmt"
"io"
"os"
"path/filepath"
"testing"
)
// TestNewKeyPair checks that New with an empty path and no options succeeds
// and selects Ed25519 as the default key type.
func TestNewKeyPair(t *testing.T) {
	kp, err := New("")
	if err != nil {
		// Stop here: kp is not usable and the check below would dereference it.
		t.Fatalf("error creating SSH key pair: %v", err)
	}
	if kp.keyType != Ed25519 {
		t.Errorf("expected default key type to be Ed25519, got %s", kp.keyType)
	}
}
// nilTest fails the test if kp is nil or if any of its key accessors returns
// a nil or empty value. A nil kp stops the test immediately, because every
// later check would call a method on a nil *KeyPair.
func nilTest(t testing.TB, kp *KeyPair) {
	t.Helper()
	if kp == nil {
		t.Fatal("expected key pair to be non-nil")
	}
	if kp.PrivateKey() == nil {
		t.Error("expected private key to be non-nil")
	}
	if kp.PublicKey() == nil {
		t.Error("expected public key to be non-nil")
	}
	if kp.RawPrivateKey() == nil {
		t.Error("expected raw private key to be non-nil")
	}
	if kp.RawProtectedPrivateKey() == nil {
		t.Error("expected raw protected private key to be non-nil")
	}
	// AuthorizedKey returns a string, so "missing" means empty rather than nil.
	if kp.AuthorizedKey() == "" {
		t.Error("expected authorized key to be non-empty")
	}
	if kp.Signer() == nil {
		t.Error("expected signer to be non-nil")
	}
}
// TestNilKeyPair creates an in-memory key pair (empty path) for each
// supported key type and checks that none of its accessors return empty
// values.
func TestNilKeyPair(t *testing.T) {
	for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
		t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
			kp, err := New("", WithKeyType(kt))
			if err != nil {
				// Without a key pair there is nothing meaningful to inspect.
				t.Fatalf("error creating SSH key pair: %v", err)
			}
			nilTest(t, kp)
		})
	}
}
// TestNilKeyPairWithPassphrase is TestNilKeyPair with a passphrase set, so
// the protected private-key accessors are exercised as well.
func TestNilKeyPairWithPassphrase(t *testing.T) {
	for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
		t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
			kp, err := New("", WithKeyType(kt), WithPassphrase("test"))
			if err != nil {
				// Without a key pair there is nothing meaningful to inspect.
				t.Fatalf("error creating SSH key pair: %v", err)
			}
			nilTest(t, kp)
		})
	}
}
// TestNilKeyPairTestdata loads the passphrase-protected fixtures at
// testdata/test_<key type> for each supported key type and checks that none
// of the resulting key pair's accessors return empty values.
func TestNilKeyPairTestdata(t *testing.T) {
	for _, kt := range []KeyType{RSA, Ed25519, ECDSA} {
		t.Run(fmt.Sprintf("test nil key pair for %s", kt), func(t *testing.T) {
			kp, err := New(filepath.Join("testdata", "test_"+kt.String()), WithPassphrase("test"), WithKeyType(kt))
			if err != nil {
				// Without a key pair there is nothing meaningful to inspect.
				t.Fatalf("error creating SSH key pair: %v", err)
			}
			nilTest(t, kp)
		})
	}
}
// TestUnsupportedCurve verifies curve validation for ECDSA keys: New must
// reject P-224 and accept P-256.
func TestUnsupportedCurve(t *testing.T) {
	if _, rejectErr := New("", WithKeyType(ECDSA), WithEllipticCurve(elliptic.P224())); rejectErr == nil {
		t.Error("expected error for unsupported curve")
	}

	if _, acceptErr := New("", WithKeyType(ECDSA), WithEllipticCurve(elliptic.P256())); acceptErr != nil {
		t.Errorf("expected no error for supported curve, got %v", acceptErr)
	}
}
// TestGenerateEd25519Keys exercises the Ed25519 path of KeyPair:
//   - it generates a key pair in memory;
//   - it writes the pair to disk;
//   - it checks that WriteKeys refuses to overwrite an existing private or
//     public key file.
//
// The subtests share k and reassign k.path, so they depend on running in
// order. The writing subtests are not run if generation fails.
func TestGenerateEd25519Keys(t *testing.T) {
	// Create temp directory for keys
	dir := t.TempDir()
	filename := "test"

	k := &KeyPair{
		path:    filepath.Join(dir, filename),
		keyType: Ed25519,
	}

	generated := t.Run("test generate SSH keys", func(t *testing.T) {
		if err := k.generateEd25519Keys(); err != nil {
			t.Fatalf("error creating SSH key pair: %v", err)
		}

		// TODO: is there a good way to validate these? Lengths seem to vary a bit,
		// so for now we're just asserting that the keys indeed exist.
		if len(k.RawPrivateKey()) == 0 {
			t.Error("error creating SSH private key PEM; key is 0 bytes")
		}
		if len(k.AuthorizedKey()) == 0 {
			t.Error("error creating SSH authorized key; key is 0 bytes")
		}
	})
	if !generated {
		// The remaining subtests need generated key material; stop rather
		// than report follow-on failures.
		t.FailNow()
	}

	t.Run("test write SSH keys", func(t *testing.T) {
		k.path = filepath.Join(dir, "ssh1", filename)
		if err := k.prepFilesystem(); err != nil {
			t.Fatalf("filesystem error: %v", err)
		}
		if err := k.WriteKeys(); err != nil {
			t.Fatalf("error writing SSH keys to %s: %v", k.path, err)
		}
		// t.Logf output is shown only with -v or when the test fails.
		t.Logf("Wrote keys to %s", k.path)
	})

	t.Run("test not overwriting existing keys", func(t *testing.T) {
		k.path = filepath.Join(dir, "ssh2", filename)
		if err := k.prepFilesystem(); err != nil {
			t.Fatalf("filesystem error: %v", err)
		}

		// Private key: an existing file at the private key path must make
		// WriteKeys fail.
		if !createEmptyFile(t, k.privateKeyPath()) {
			return
		}
		if err := k.WriteKeys(); err == nil {
			t.Errorf("we wrote the private key over an existing file, but we were not supposed to")
		}
		// Remove the placeholder so the second check targets the public key file.
		if err := os.Remove(k.privateKeyPath()); err != nil {
			t.Fatalf("could not remove file %s: %v", k.privateKeyPath(), err)
		}

		// Public key: an existing file at the public key path must also make
		// WriteKeys fail.
		if !createEmptyFile(t, k.publicKeyPath()) {
			return
		}
		if err := k.WriteKeys(); err == nil {
			t.Errorf("we wrote the public key over an existing file, but we were not supposed to")
		}
	})
}
// TestGenerateECDSAKeys exercises the ECDSA (P-384) path of KeyPair:
//   - it generates a key pair in memory;
//   - it writes the pair to disk;
//   - it checks that WriteKeys refuses to overwrite an existing private or
//     public key file.
//
// The subtests share k and reassign k.path, so they depend on running in
// order. The writing subtests are not run if generation fails.
func TestGenerateECDSAKeys(t *testing.T) {
	// Create temp directory for keys
	dir := t.TempDir()
	filename := "test"

	k := &KeyPair{
		path:    filepath.Join(dir, filename),
		keyType: ECDSA,
		ec:      elliptic.P384(),
	}

	generated := t.Run("test generate SSH keys", func(t *testing.T) {
		if err := k.generateECDSAKeys(k.ec); err != nil {
			t.Fatalf("error creating SSH key pair: %v", err)
		}

		// TODO: is there a good way to validate these? Lengths seem to vary a bit,
		// so for now we're just asserting that the keys indeed exist.
		if len(k.RawPrivateKey()) == 0 {
			t.Error("error creating SSH private key PEM; key is 0 bytes")
		}
		if len(k.AuthorizedKey()) == 0 {
			t.Error("error creating SSH public key; key is 0 bytes")
		}
	})
	if !generated {
		// The remaining subtests need generated key material; stop rather
		// than report follow-on failures.
		t.FailNow()
	}

	t.Run("test write SSH keys", func(t *testing.T) {
		k.path = filepath.Join(dir, "ssh1", filename)
		if err := k.prepFilesystem(); err != nil {
			t.Fatalf("filesystem error: %v", err)
		}
		if err := k.WriteKeys(); err != nil {
			t.Fatalf("error writing SSH keys to %s: %v", k.path, err)
		}
		// t.Logf output is shown only with -v or when the test fails.
		t.Logf("Wrote keys to %s", k.path)
	})

	t.Run("test not overwriting existing keys", func(t *testing.T) {
		k.path = filepath.Join(dir, "ssh2", filename)
		if err := k.prepFilesystem(); err != nil {
			t.Fatalf("filesystem error: %v", err)
		}

		// Private key: an existing file at the private key path must make
		// WriteKeys fail.
		if !createEmptyFile(t, k.privateKeyPath()) {
			return
		}
		if err := k.WriteKeys(); err == nil {
			t.Errorf("we wrote the private key over an existing file, but we were not supposed to")
		}
		// Remove the placeholder so the second check targets the public key file.
		if err := os.Remove(k.privateKeyPath()); err != nil {
			t.Fatalf("could not remove file %s: %v", k.privateKeyPath(), err)
		}

		// Public key: an existing file at the public key path must also make
		// WriteKeys fail.
		if !createEmptyFile(t, k.publicKeyPath()) {
			return
		}
		if err := k.WriteKeys(); err == nil {
			t.Errorf("we wrote the public key over an existing file, but we were not supposed to")
		}
	})
}
// createEmptyFile is a test utility that creates an empty file at path,
// first creating any missing parent directories with mode 0700. An existing
// file at path is truncated, because os.Create truncates. Failures are
// reported through t.Errorf. It returns false on failure, so the caller can
// bail out, and true on success.
func createEmptyFile(t *testing.T, path string) (ok bool) {
	t.Helper()

	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o700); err != nil {
		t.Errorf("could not create directory %s: %v", dir, err)
		return false
	}

	f, err := os.Create(path)
	if err != nil {
		t.Errorf("could not create file %s: %v", path, err)
		return false
	}
	if err := f.Close(); err != nil {
		t.Errorf("could not close file %s: %v", path, err)
		return false
	}

	if testing.Verbose() {
		t.Logf("created dummy file at %s", path)
	}
	return true
}
// TestGeneratePublicKeyWithEmptyDir writes a fresh key pair of each supported
// type into an empty temporary directory. It then checks that the written
// <path>.pub file carries the same public key as the in-memory pair's
// RawAuthorizedKey.
func TestGeneratePublicKeyWithEmptyDir(t *testing.T) {
	for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} {
		t.Run(fmt.Sprintf("test generate public key with empty dir for %s", keyType), func(t *testing.T) {
			// t.TempDir is removed automatically when the subtest ends.
			fp := filepath.Join(t.TempDir(), "testkey")
			k, err := New(fp, WithKeyType(keyType), WithWrite())
			if err != nil {
				t.Fatalf("error creating SSH key pair: %v", err)
			}

			f, err := os.Open(fp + ".pub")
			if err != nil {
				t.Fatalf("error opening SSH key file: %v", err)
			}
			defer f.Close()

			fc, err := io.ReadAll(f)
			if err != nil {
				t.Fatalf("error reading SSH key file: %v", err)
			}

			// An authorized_keys entry is "<key type> <base64 key> [comment]".
			// Only the first two fields are compared, so an optional trailing
			// comment in the file does not count as a mismatch.
			want := bytes.Fields(k.RawAuthorizedKey())
			got := bytes.Fields(fc)
			if len(want) < 2 || len(got) < 2 ||
				!bytes.Equal(want[0], got[0]) || !bytes.Equal(want[1], got[1]) {
				t.Errorf("error key mismatch\nauthorized key:\n%s\n\nactual file:\n%s", k.RawAuthorizedKey(), string(fc))
			}
		})
	}
}
// TestGenerateKeyWithPassphrase writes a passphrase-protected key pair of
// each supported type and reloads it from disk with the same passphrase. It
// then checks that the reloaded pair's RawPrivateKey is not byte-identical to
// the private key file's contents.
func TestGenerateKeyWithPassphrase(t *testing.T) {
	const passphrase = "testpass"
	keyTypes := []KeyType{RSA, ECDSA, Ed25519}

	for i := range keyTypes {
		kt := keyTypes[i]
		t.Run("test generate key with passphrase", func(t *testing.T) {
			keyPath := filepath.Join(t.TempDir(), "testph")
			if _, err := New(keyPath, WithKeyType(kt), WithPassphrase(passphrase), WithWrite()); err != nil {
				t.Fatalf("error creating SSH key pair: %v", err)
			}

			file, err := os.Open(keyPath)
			if err != nil {
				t.Fatalf("error opening SSH key file: %v", err)
			}
			defer file.Close()

			onDisk, err := io.ReadAll(file)
			if err != nil {
				t.Fatalf("error reading SSH key file: %v", err)
			}

			reloaded, err := New(keyPath, WithKeyType(kt), WithPassphrase(passphrase))
			if err != nil {
				t.Fatalf("error reading SSH key pair: %v", err)
			}

			// With a passphrase set, the file is expected to hold a protected
			// form of the key, so it should not equal the raw private key bytes.
			if bytes.Equal(reloaded.RawPrivateKey(), onDisk) {
				t.Errorf("encrypted private key matches file contents")
			}

			t.Cleanup(func() {
				os.Remove(keyPath)
				os.Remove(keyPath + ".pub")
			})
		})
	}
}
// TestReadingKeyWithPassphrase loads the passphrase-protected key at
// testdata/test once per key type (via WithKeyType) and checks that New
// succeeds. Each key type runs as its own subtest, so one failure neither
// hides the others nor leaves the failing type unidentified.
//
// NOTE(review): every iteration reads the same file; only the WithKeyType
// option changes. Confirm whether the per-type fixtures
// (testdata/test_<type>, used by TestNilKeyPairTestdata) were intended here.
func TestReadingKeyWithPassphrase(t *testing.T) {
	kp := filepath.Join("testdata", "test")
	for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} {
		t.Run(keyType.String(), func(t *testing.T) {
			if _, err := New(kp, WithKeyType(keyType), WithPassphrase("test")); err != nil {
				t.Fatalf("error reading SSH key pair: %v", err)
			}
		})
	}
}
// TestKeynameSuffix requests a path that already ends in "_" plus the key
// type's String() form. It checks that New with WithWrite writes the private
// key to exactly that path and the public key to that path plus ".pub".
func TestKeynameSuffix(t *testing.T) {
	for _, keyType := range []KeyType{RSA, ECDSA, Ed25519} {
		t.Run("test keyname suffix for "+keyType.String(), func(t *testing.T) {
			// t.TempDir is removed automatically when the subtest ends.
			fp := filepath.Join(t.TempDir(), "testkey_"+keyType.String())
			if _, err := New(fp, WithKeyType(keyType), WithWrite()); err != nil {
				t.Fatalf("error creating SSH key pair: %v", err)
			}

			// Any Stat error (not only "does not exist") means the file is not
			// where it should be, or cannot be checked.
			if _, err := os.Stat(fp); err != nil {
				t.Errorf("private key file %s does not exist: %v", fp, err)
			}
			if _, err := os.Stat(fp + ".pub"); err != nil {
				t.Errorf("public key file %s does not exist: %v", fp+".pub", err)
			}
		})
	}
}
// TestExpandPath checks that expandPath expands an environment variable
// ($TEMP) and a leading "~" (the current user's home directory).
func TestExpandPath(t *testing.T) {
	tmpdir := t.TempDir()
	// t.Setenv restores TEMP's previous value, or unsets it if it was unset,
	// when the test ends. TEMP is normally set on Windows, so blindly
	// unsetting it would leak into the rest of the test process.
	t.Setenv("TEMP", tmpdir)

	// Test environment variable expansion
	if fp := expandPath(filepath.Join("$TEMP", "testkey")); fp != filepath.Join(tmpdir, "testkey") {
		t.Errorf("error expanding path: %s", fp)
	}

	// Test tilde expansion
	homedir, err := os.UserHomeDir()
	if err != nil {
		t.Fatalf("error getting user home directory: %v", err)
	}
	if fp := expandPath(filepath.Join("~", "testkey")); fp != filepath.Join(homedir, "testkey") {
		t.Errorf("error expanding path: %s", fp)
	}
}