-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathclient.go
108 lines (92 loc) · 2.73 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package jwks
import (
"context"
"fmt"
"time"
"golang.org/x/sync/semaphore"
"gopkg.in/square/go-jose.v2"
)
// JWKSClient looks up JSON Web Keys by their key ID ("kid").
type JWKSClient interface {
	// GetSignatureKey returns the key identified by keyId for use "sig".
	GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error)

	// GetEncryptionKey returns the key identified by keyId for use "enc".
	GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error)

	// GetKey returns the key identified by keyId for an arbitrary use value.
	GetKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error)
}
// jWKSClient is the JWKSClient implementation returned by NewClient.
type jWKSClient struct {
// source supplies the JSON Web Key Set that keys are looked up in.
source JWKSSource
// cache holds *cacheEntry values keyed by key ID.
cache Cache
// refresh is added to the save time to compute when GetKey should start a
// background refresh of a cached entry.
refresh time.Duration
// sem is created with weight 1 and only taken via TryAcquire, so at most one
// background refresh runs at a time and callers never block on it.
sem *semaphore.Weighted
}
// cacheEntry is the value jWKSClient stores in its cache for each key ID.
type cacheEntry struct {
// jwk is the cached key returned to callers.
jwk *jose.JSONWebKey
// refresh is a Unix timestamp in whole seconds; once it has passed, GetKey
// starts a background refresh while still serving jwk.
refresh int64
}
// NewDefaultClient creates a client backed by DefaultCache(ttl) that starts a
// background refresh for a cached key once it is older than refresh.
//
// It panics if refresh is negative or is not strictly less than ttl.
func NewDefaultClient(source JWKSSource, refresh time.Duration, ttl time.Duration) JWKSClient {
	// Check the sign first so a negative refresh always reports the more
	// specific problem.
	if refresh < 0 {
		panic(fmt.Sprintf("invalid refresh: %v", refresh))
	}
	// NOTE(review): assumes DefaultCache evicts entries after ttl; refresh has
	// to fire before that, or entries would expire before being refreshed.
	if refresh >= ttl {
		panic(fmt.Sprintf("invalid refresh: %v greater than or equal to ttl: %v", refresh, ttl))
	}
	return NewClient(source, DefaultCache(ttl), refresh)
}
// NewClient builds a JWKSClient that looks keys up in source, stores them in
// cache, and schedules a background refresh once an entry is older than
// refresh. The weight-1 semaphore allows at most one such refresh at a time.
func NewClient(source JWKSSource, cache Cache, refresh time.Duration) JWKSClient {
	client := new(jWKSClient)
	client.source = source
	client.cache = cache
	client.refresh = refresh
	client.sem = semaphore.NewWeighted(1)
	return client
}
// GetSignatureKey delegates to GetKey with use "sig".
func (c *jWKSClient) GetSignatureKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
	const signatureUse = "sig"
	return c.GetKey(ctx, keyId, signatureUse)
}
// GetEncryptionKey delegates to GetKey with use "enc".
func (c *jWKSClient) GetEncryptionKey(ctx context.Context, keyId string) (*jose.JSONWebKey, error) {
	const encryptionUse = "enc"
	return c.GetKey(ctx, keyId, encryptionUse)
}
// GetKey returns the JSON Web Key identified by keyId for the given use.
//
// On a cache hit the cached key is returned immediately. If the entry's
// refresh time has passed, a single background refresh is started (guarded by
// c.sem) while the stale key is still served. In these cases the key is
// fetched synchronously from the source and cached:
//   - a cache miss,
//   - a cached value that is not a *cacheEntry,
//   - an entry without a key,
//   - a key that explicitly declares a different use.
func (c *jWKSClient) GetKey(ctx context.Context, keyId string, use string) (jwk *jose.JSONWebKey, err error) {
	val, found := c.cache.Get(keyId)
	entry, ok := val.(*cacheEntry)
	// NewClient accepts an arbitrary Cache, so guard the type assertion rather
	// than panicking, and never serve a key published for another use.
	if !found || !ok || entry == nil || entry.jwk == nil ||
		(entry.jwk.Use != "" && entry.jwk.Use != use) {
		return c.refreshKey(ctx, keyId, use)
	}

	if time.Now().After(time.Unix(entry.refresh, 0)) && c.sem.TryAcquire(1) {
		// The refresh outlives this call, so it must not inherit the caller's
		// cancellation: a request-scoped ctx is usually cancelled as soon as
		// the request completes, which would abort the refresh.
		// NOTE(review): the detached context has no deadline; confirm that
		// JWKSSource enforces its own timeout, otherwise a hung fetch keeps
		// c.sem held and blocks further background refreshes.
		bg := valueOnlyContext{parent: ctx}
		go func() {
			defer c.sem.Release(1)
			if _, err := c.refreshKey(bg, keyId, use); err != nil {
				logger.Printf("unable to refresh key: %v", err)
			}
		}()
	}
	return entry.jwk, nil
}

// valueOnlyContext exposes the values of parent but none of its deadline or
// cancellation. It mirrors context.WithoutCancel (Go 1.21+) for older toolchains.
type valueOnlyContext struct {
	parent context.Context
}

// Deadline reports that the context has no deadline.
func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) {
	return time.Time{}, false
}

// Done returns nil, meaning the context is never cancelled.
func (valueOnlyContext) Done() <-chan struct{} {
	return nil
}

// Err always returns nil because the context is never cancelled.
func (valueOnlyContext) Err() error {
	return nil
}

// Value looks key up in the parent context.
func (v valueOnlyContext) Value(key interface{}) interface{} {
	return v.parent.Value(key)
}
// refreshKey fetches keyId for use from the source and, only on success,
// stores it in the cache before returning it. A fetch error is returned
// unchanged and leaves the cache untouched.
func (c *jWKSClient) refreshKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) {
	key, fetchErr := c.fetchJSONWebKey(ctx, keyId, use)
	if fetchErr == nil {
		c.save(keyId, key)
		return key, nil
	}
	return nil, fetchErr
}
// save caches jwk under keyId together with its refresh time, now + c.refresh,
// stored as whole Unix seconds.
func (c *jWKSClient) save(keyId string, jwk *jose.JSONWebKey) {
	refreshAt := time.Now().Add(c.refresh)
	entry := &cacheEntry{refresh: refreshAt.Unix(), jwk: jwk}
	c.cache.Set(keyId, entry)
}
// fetchJSONWebKey loads the current key set from the source and selects the
// key identified by keyId whose "use" parameter matches use.
//
// A key that declares exactly the requested use is preferred. A key that
// declares no use at all is accepted as a fallback, because the parameter is
// optional. Errors from the source are returned unchanged.
func (c *jWKSClient) fetchJSONWebKey(ctx context.Context, keyId string, use string) (*jose.JSONWebKey, error) {
	jsonWebKeySet, err := c.source.JSONWebKeySet(ctx)
	if err != nil {
		return nil, err
	}
	keys := jsonWebKeySet.Key(keyId)
	if len(keys) == 0 {
		return nil, fmt.Errorf("JWK is not found: %s", keyId)
	}
	// Several keys may share a kid (for example one per use), so filter on
	// use instead of returning whichever key comes first.
	var fallback *jose.JSONWebKey
	for i := range keys {
		key := &keys[i]
		switch key.Use {
		case use:
			return key, nil
		case "":
			if fallback == nil {
				fallback = key
			}
		}
	}
	if fallback != nil {
		return fallback, nil
	}
	return nil, fmt.Errorf("JWK is not found %s, use: %s", keyId, use)
}