diff --git a/jwt/jwt.go b/jwt/jwt.go index f1fbb635..52c7e559 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "github.com/clerk/clerk-sdk-go/v2" @@ -103,10 +104,9 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er return claims, nil } +// Retrieve the JSON web key for the provided id from the set. func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) { - // TODO Avoid multiple requests by caching results for the same - // instance. - jwks, err := jwks.Get(ctx, &jwks.GetParams{}) + jwks, err := getJWKSWithCache(ctx) if err != nil { return nil, err } @@ -118,6 +118,44 @@ func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) { return nil, fmt.Errorf("no jwk key found for kid %s", kid) } +// Returns the JSON web key set. Tries a cached value first, but if +// there's no value or the entry has expired, it will fetch the set +// from the API and cache the value. +func getJWKSWithCache(ctx context.Context) (*clerk.JSONWebKeySet, error) { + const cacheKey = "/v1/jwks" + var jwks *clerk.JSONWebKeySet + var err error + + // Try the cache first. Make sure we have a non-expired entry and + // that the value is a valid JWKS. + entry, ok := getCache().Get(cacheKey) + if ok && !entry.HasExpired() { + jwks, ok = entry.GetValue().(*clerk.JSONWebKeySet) + if !ok || jwks == nil || len(jwks.Keys) == 0 { + jwks, err = forceGetJWKS(ctx, cacheKey) + if err != nil { + return nil, err + } + } + } else { + jwks, err = forceGetJWKS(ctx, cacheKey) + if err != nil { + return nil, err + } + } + return jwks, err +} + +// Fetches the JSON web key set from the API and caches it. +func forceGetJWKS(ctx context.Context, cacheKey string) (*clerk.JSONWebKeySet, error) { + jwks, err := jwks.Get(ctx, &jwks.GetParams{}) + if err != nil { + return nil, err + } + getCache().Set(cacheKey, jwks, time.Now().UTC().Add(time.Hour)) + return jwks, nil +} + func isValidIssuer(iss string) bool { return strings.HasPrefix(iss, "https://clerk.") || strings.Contains(iss, ".clerk.accounts") @@ -154,3 +192,66 @@ func Decode(_ context.Context, params *DecodeParams) (*clerk.Claims, error) { Extra: extraClaims, }, nil } + +// Caching store. +type cache struct { + mu sync.RWMutex + entries map[string]*cacheEntry +} + +// Get returns the cache entry for the provided key, if one exists. +func (c *cache) Get(key string) (*cacheEntry, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + entry, ok := c.entries[key] + return entry, ok +} + +// Set adds a new entry with the provided value in the cache under +// the provided key. An expiration date will be set for the entry. +func (c *cache) Set(key string, value any, expiresAt time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + c.entries[key] = &cacheEntry{ + value: value, + expiresAt: expiresAt, + } +} + +// A cache entry has a value and an expiration date. +type cacheEntry struct { + value any + expiresAt time.Time +} + +// HasExpired returns true if the cache entry's expiration date +// has passed. +func (entry *cacheEntry) HasExpired() bool { + if entry == nil { + return true + } + return entry.expiresAt.Before(time.Now()) +} + +// GetValue returns the cache entry's value. +func (entry *cacheEntry) GetValue() any { + if entry == nil { + return nil + } + return entry.value +} + +var cacheInit sync.Once + +// A "singleton" cache for the package. +var defaultCache *cache + +// Lazy initialize and return the default cache singleton. +func getCache() *cache { + cacheInit.Do(func() { + defaultCache = &cache{ + entries: map[string]*cacheEntry{}, + } + }) + return defaultCache +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 00000000..a2673819 --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,75 @@ +package jwt + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/clerk/clerk-sdk-go/v2" + "github.com/clerk/clerk-sdk-go/v2/clerktest" + "github.com/stretchr/testify/require" +) + +func TestVerify_InvalidToken(t *testing.T) { + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: &http.Client{ + Transport: &clerktest.RoundTripper{}, + }, + })) + + ctx := context.Background() + _, err := Verify(ctx, &VerifyParams{ + Token: "this-is-not-a-token", + }) + require.Error(t, err) +} + +func TestVerify_Cache(t *testing.T) { + ctx := context.Background() + totalRequests := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/v1/jwks" { + totalRequests++ + } + _, err := w.Write([]byte(`{ + "keys": [{ + "use": "sig", + "kty": "RSA", + "kid": "ins_123", + "alg": "RS256", + "n": "9m1LJW0dgEuK8SnN1Oy4LY8vaWABVS-hBTMA--_4LN1PZlMS5B2RPL85WkXYlHb0KXOSVrFKZLwYP-a9l3MFlW2YrPVAIvYfqPyqY5fmSEf-2qfrwosIhB2NSHyNRBQQ8-BX1RO9rIXIqYDKxGqktqMvYJmEGClmijbmFyQb2hpHD5PDbAB_DZvpZTEzWcQBL2ytHehILkYfg-ZZRyt7O8h5Gdy1v_TUlg8iMvchHlAkrIAmXNQigZmX_lne91tW8t4KMNJRfmUyLVCLbPnwxlmXXcice-0tmFw0OkCOteNWBeRNctJ3AIreGMzaJOJ2HeSUmJoX8iRKLLT3fsURLw", + "e": "AQAB" + }] +}`)) + require.NoError(t, err) + })) + defer ts.Close() + + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: ts.Client(), + URL: clerk.String(ts.URL), + })) + + token := "eyJhbGciOiJSUzI1NiIsImNhdCI6ImNsX0I3ZDRQRDExMUFBQSIsImtpZCI6Imluc18yOWR6bUdmQ3JydzdSMDRaVFFZRDNKSTB5dkYiLCJ0eXAiOiJKV1QifQ.eyJhenAiOiJodHRwczovL2Rhc2hib2FyZC5wcm9kLmxjbGNsZXJrLmNvbSIsImV4cCI6MTcwNzMwMDMyMiwiaWF0IjoxNzA3MzAwMjYyLCJpc3MiOiJodHRwczovL2NsZXJrLnByb2QubGNsY2xlcmsuY29tIiwibmJmIjoxNzA3MzAwMjUyLCJvcmdzIjp7Im9yZ18ySUlwcVIxenFNeHJQQkhSazNzTDJOSnJUQkQiOiJvcmc6YWRtaW4iLCJvcmdfMllHMlNwd0IzWEJoNUo0ZXF5elFVb0dXMjVhIjoib3JnOmFkbWluIiwib3JnXzJhZzJ6bmgxWGFjTXI0dGRXYjZRbEZSQ2RuaiI6Im9yZzphZG1pbiIsIm9yZ18yYWlldHlXa3VFSEhaRmRSUTFvVjYzMnZWaFciOiJvcmc6YWRtaW4ifSwic2lkIjoic2Vzc18yYm84b2gyRnIyeTNueVoyRVZQYktBd2ZvaU0iLCJzdWIiOiJ1c2VyXzI5ZTBXTnp6M245V1Q5S001WlpJYTBVVjNDNyJ9.6GtQafMBYY3Ij3pKHOyBYKt76LoLeBC71QUY_ho3k5nb0FBSvV0upKFLPBvIXNuF7hH0FK2QqDcAmrhbzAI-2qF_Ynve8Xl4VZCRpbTuZI7uL-tVjCvMffEIH-BHtrZ-QcXhEmNFQNIPyZTu21242he7U6o4S8st_aLmukWQzj_4qir7o5_fmVhm7YkLa0gYG5SLjkr2czwem1VGFHEVEOrHjun-g6eMnDNMMMysIOkZFxeqiCnqpc4u1V7Z7jfoK0r_-Unp8mGGln5KWYMCQyp1l1SkGwugtxeWfSbE4eklKRmItGOdVftvTyG16kDGpzsb22AQGtg65Iygni4PHg" + // Providing a custom key will not trigger a request to fetch the + // key set. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + JWK: &clerk.JSONWebKey{}, + }) + require.Equal(t, 0, totalRequests) + + // Verify without providing a key. The method will trigger a request + // to fetch the key set. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + }) + require.Equal(t, 1, totalRequests) + // Verifying again won't trigger a request because the key set is + // cached. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + }) + require.Equal(t, 1, totalRequests) +}