-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate_test.go
157 lines (126 loc) · 5.45 KB
/
validate_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package jwt
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
jwtpkg "github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
)
// generateRSAKey returns a freshly generated 2048-bit RSA private key drawn
// from crypto/rand. Any generation failure is wrapped and returned with a nil key.
func generateRSAKey() (*rsa.PrivateKey, error) {
	key, genErr := rsa.GenerateKey(rand.Reader, 2048)
	if genErr == nil {
		return key, nil
	}
	return nil, fmt.Errorf("failed to generate RSA key: %w", genErr)
}
// generateTestJWT builds a test JWT signed with signingKey using method.
//
// The token carries kid in its header. Its claims are "aud" (audience), "exp"
// (expiry as Unix seconds), and "iat"/"nbf", both set to the same creation
// instant. It returns the signed compact token string, or an error wrapping
// the signing failure.
func generateTestJWT(signingKey *rsa.PrivateKey, kid, audience string, expiry time.Time, method jwtpkg.SigningMethod) (string, error) {
	// One timestamp for both claims so iat and nbf can never differ by a
	// second when the clock ticks between two time.Now() calls.
	now := time.Now().Unix()

	// Build the claims explicitly instead of type-asserting token.Claims,
	// which would panic if the token's default claims type ever changed.
	claims := jwtpkg.MapClaims{
		"aud": audience,
		"exp": expiry.Unix(),
		"iat": now,
		"nbf": now,
	}
	token := jwtpkg.NewWithClaims(method, claims)
	token.Header["kid"] = kid

	tokenString, err := token.SignedString(signingKey)
	if err != nil {
		return "", fmt.Errorf("failed to sign token: %w", err)
	}
	return tokenString, nil
}
// --- Test ---

// TestJWTMiddleware drives JWTMiddleware end to end. A JWTValidator backed by
// a static, in-memory JWKS guards a handler that always replies 200 "OK". Each
// subtest sends GET /protected with a differently crafted bearer token and
// checks the middleware's response.
func TestJWTMiddleware(t *testing.T) {
	// signingKey's public half is published in the JWKS below.
	signingKey, err := generateRSAKey()
	if !assert.NoError(t, err, "Failed to generate signing key") {
		t.FailNow() // every later step dereferences signingKey
	}
	// invalidSigningKey is never published, so signatures made with it cannot verify.
	invalidSigningKey, err := generateRSAKey()
	if !assert.NoError(t, err, "Failed to generate invalid signing key") {
		t.FailNow()
	}

	// Static JWKS: the RSA modulus (n) and public exponent (e) are encoded as
	// unpadded base64url big-endian byte strings, the JWK form of an RSA public key.
	const kid = "test-kid-12345"
	staticJWKS := &JWKS{
		Keys: []JSONWebKey{
			{
				Kid: kid,
				Kty: "RSA",
				E:   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(signingKey.E)).Bytes()),
				N:   base64.RawURLEncoding.EncodeToString(signingKey.N.Bytes()),
			},
		},
	}

	// Fetcher holding only the preset keys. Its unexported fields are set
	// directly, so this test relies on JWKSFetcher's internal layout.
	minimalFetcher := &JWKSFetcher{
		jwks:  staticJWKS,
		mutex: &sync.RWMutex{},
	}

	// The validator accepts only this audience and only RS256.
	const audience = "api://my-test-api"
	validator := NewJWTValidator(minimalFetcher, []string{audience}, []string{jwtpkg.SigningMethodRS256.Name})

	// The wrapped handler always answers 200 "OK". Any other status or body
	// therefore originates in the middleware.
	jwtMiddleware := JWTMiddleware(validator)
	nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK")) // a recorder write cannot meaningfully fail here
	})
	testHandler := jwtMiddleware(nextHandler)

	// newToken returns an RS256 token with the shared kid. It aborts the
	// calling subtest if signing fails, instead of sending an empty token.
	newToken := func(t *testing.T, key *rsa.PrivateKey, aud string, expiry time.Time) string {
		t.Helper()
		token, err := generateTestJWT(key, kid, aud, expiry, jwtpkg.SigningMethodRS256)
		if !assert.NoError(t, err) {
			t.FailNow()
		}
		return token
	}

	// serve sends GET /protected with the given bearer token through the
	// middleware and returns the recorded response.
	serve := func(token string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(http.MethodGet, "/protected", nil)
		req.Header.Set("Authorization", "Bearer "+token)
		recorder := httptest.NewRecorder()
		testHandler.ServeHTTP(recorder, req)
		return recorder
	}

	// --- Test Cases ---
	t.Run("Valid JWT", func(t *testing.T) {
		recorder := serve(newToken(t, signingKey, audience, time.Now().Add(time.Hour)))
		assert.Equal(t, http.StatusOK, recorder.Code, "Expected status OK for valid token")
		assert.Equal(t, "OK", recorder.Body.String(), "Expected 'OK' body for valid token")
	})

	t.Run("Invalid JWT - Bad Signature", func(t *testing.T) {
		// Signed with an unpublished key but advertising the published kid.
		recorder := serve(newToken(t, invalidSigningKey, audience, time.Now().Add(time.Hour)))
		assert.Equal(t, http.StatusUnauthorized, recorder.Code, "Expected status Unauthorized for invalid signature")
		// NOTE(review): assumes the validator reports signature failures as a parse error.
		assert.Contains(t, recorder.Body.String(), "failed to parse jwt token", "Expected parsing error message")
	})

	t.Run("Invalid JWT - Wrong Audience", func(t *testing.T) {
		recorder := serve(newToken(t, signingKey, "api://wrong-audience", time.Now().Add(time.Hour)))
		assert.Equal(t, http.StatusUnauthorized, recorder.Code, "Expected status Unauthorized for wrong audience")
		assert.Contains(t, recorder.Body.String(), "invalid token", "Expected invalid token message for wrong audience")
	})

	t.Run("Invalid JWT - Expired", func(t *testing.T) {
		// Expired one hour ago. golang-jwt's Parse rejects expired tokens itself.
		recorder := serve(newToken(t, signingKey, audience, time.Now().Add(-time.Hour)))
		assert.Equal(t, http.StatusUnauthorized, recorder.Code, "Expected status Unauthorized for expired token")
		assert.Contains(t, recorder.Body.String(), "failed to parse jwt token", "Expected parsing error message for expired token")
	})
}