forked from gambol99/go-oidc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjwks.go
205 lines (171 loc) · 5.32 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/pquerna/cachecontrol"
jose "gopkg.in/square/go-jose.v2"
)
// keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect
// server.
//
// It is consulted by remoteKeySet.verify after every cached key has failed to
// verify a signature: the key set is re-fetched from the server only if the
// cached keys have already expired or will expire within this amount of time.
// If the cached keys remain valid for longer than that, verification fails
// without contacting the server.
const keysExpiryDelta = 30 * time.Second
// newRemoteKeySet builds a remoteKeySet that reads its keys from jwksURL.
//
// ctx is stored on the returned value. now supplies the current time; when it
// is nil, time.Now is used instead.
func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
	// Start from the default clock and override it only when the caller
	// supplied one.
	ks := &remoteKeySet{
		jwksURL: jwksURL,
		ctx:     ctx,
		now:     time.Now,
	}
	if now != nil {
		ks.now = now
	}
	return ks
}
// remoteKeySet holds the signing keys fetched from a JWKS endpoint together
// with their expiry; verify checks signatures against these keys.
type remoteKeySet struct {
// jwksURL is the address updateKeys sends its GET request to.
jwksURL string
// ctx is handed to doRequest by updateKeys on every key fetch.
ctx context.Context
// now returns the current time; newRemoteKeySet substitutes time.Now when
// it is given nil.
now func() time.Time
// mu guards the fields declared below it (inflight, cachedKeys, expiry).
// In the code shown here the fields above are only written by
// newRemoteKeySet and are read without taking the lock.
mu sync.Mutex
// inflight suppresses parallel execution of updateKeys and allows
// multiple goroutines to wait for its result. It is nil when no fetch is
// in progress.
inflight *inflight
// A set of cached keys and their expiry. Both are replaced together after
// a successful updateKeys call.
cachedKeys []jose.JSONWebKey
expiry time.Time
}
// inflight is used to wait on some in-flight request from multiple goroutines.
// It carries the outcome of a single key fetch: done() records the result and
// closes doneCh, after which result() may be read.
type inflight struct {
// doneCh is closed by done() to signal that keys and err have been set.
doneCh chan struct{}
// keys and err are the values passed to done(); they are returned by
// result().
keys []jose.JSONWebKey
err error
}
// newInflight returns an inflight whose completion channel is still open,
// representing a request that has not finished yet.
func newInflight() *inflight {
	pending := new(inflight)
	pending.doneCh = make(chan struct{})
	return pending
}
// wait exposes the completion channel as receive-only so any number of
// goroutines can block on it. A receive succeeds once done() has closed the
// channel; only after that may result() be inspected.
func (i *inflight) wait() <-chan struct{} {
	var completed <-chan struct{} = i.doneCh
	return completed
}
// done stores the outcome of the inflight request and then closes the
// completion channel, releasing every goroutine blocked on wait().
//
// It must be called by exactly one goroutine, exactly once: the channel is
// closed here, and closing it a second time would panic.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
	// Record the result before closing so that waiters woken by the close
	// observe the final values.
	i.keys, i.err = keys, err
	close(i.doneCh)
}
// result returns the keys and error recorded by done(). It must not be called
// before a receive on the wait() channel has succeeded.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
	keys, err := i.keys, i.err
	return keys, err
}
// verify checks the signature of jws against the key set and returns the
// verified payload.
//
// Cached keys are tried first without looking at their expiry, which keeps
// verification working while the provider is unavailable. If none of them
// verifies the signature, the keys are re-fetched only when the cached set has
// expired or expires within keysExpiryDelta; otherwise an error is returned
// straight away. ctx bounds how long the call waits for that re-fetch.
func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
	// We don't support JWTs signed with multiple signatures, so only the key
	// ID of the first signature (if any) is consulted.
	var wantID string
	if len(jws.Signatures) > 0 {
		wantID = jws.Signatures[0].Header.KeyID
	}

	// tryKeys attempts verification with every candidate whose ID matches
	// wantID; an empty wantID matches every candidate.
	tryKeys := func(candidates []jose.JSONWebKey) ([]byte, bool) {
		for idx := range candidates {
			candidate := candidates[idx]
			if wantID != "" && candidate.KeyID != wantID {
				continue
			}
			payload, err := jws.Verify(&candidate)
			if err == nil {
				return payload, true
			}
		}
		return nil, false
	}

	// Don't check expiry yet. This optimizes for when the provider is unavailable.
	cached, expiry := r.keysFromCache()
	if payload, ok := tryKeys(cached); ok {
		return payload, nil
	}

	// A refresh is due only when now+delta is strictly after the expiry.
	if refreshDue := r.now().Add(keysExpiryDelta).After(expiry); !refreshDue {
		// Keys haven't expired, don't refresh.
		return nil, errors.New("failed to verify id token signature")
	}

	fetched, err := r.keysFromRemote(ctx)
	if err != nil {
		return nil, fmt.Errorf("fetching keys %v", err)
	}
	if payload, ok := tryKeys(fetched); ok {
		return payload, nil
	}
	return nil, errors.New("failed to verify id token signature")
}
// keysFromCache returns the currently cached keys and their expiry. Both
// values are read while holding r.mu so they belong to the same update.
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
	r.mu.Lock()
	keys, expiry = r.cachedKeys, r.expiry
	r.mu.Unlock()
	return keys, expiry
}
// keysFromRemote syncs the key set from the remote set, records the values in the
// cache, and returns the key set.
//
// Concurrent callers share a single fetch: the first caller starts it and the
// rest wait on the same inflight request. A caller whose ctx is done stops
// waiting and gets ctx.Err(), but the shared fetch itself keeps running.
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
	// The inflight field may only be inspected while holding the lock.
	r.mu.Lock()
	pending := r.inflight
	if pending == nil {
		// No fetch is running, so this caller starts one and publishes it
		// for other callers to join.
		pending = newInflight()
		r.inflight = pending

		// This goroutine has exclusive ownership over the request it was
		// started for and gives it up by clearing r.inflight when finished.
		go func() {
			// Fetch the keys and wake the waiters as soon as a result exists.
			keys, expiry, err := r.updateKeys()
			pending.done(keys, err)

			// Take the lock to store the new keys and to mark that no
			// request is in flight any more.
			r.mu.Lock()
			defer r.mu.Unlock()

			// A failed fetch leaves the previously cached keys in place.
			if err == nil {
				r.cachedKeys, r.expiry = keys, expiry
			}

			// Clear the field so a later caller can start a fresh fetch.
			r.inflight = nil
		}()
	}
	r.mu.Unlock()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-pending.wait():
		return pending.result()
	}
}
// updateKeys fetches the JSON Web Key Set from r.jwksURL and returns its keys
// along with the time at which they should be considered expired.
//
// Any failure to build the request, perform it, read the body, or decode the
// body as a key set is returned as an error, as is any status other than
// 200 OK. The expiry is taken from the response's cache-control information
// when that yields a time after the current one, and is the current time
// otherwise.
func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
	// noExpiry is the zero time returned alongside every error.
	var noExpiry time.Time

	request, err := http.NewRequest("GET", r.jwksURL, nil)
	if err != nil {
		return nil, noExpiry, fmt.Errorf("oidc: can't create request: %v", err)
	}

	response, err := doRequest(r.ctx, request)
	if err != nil {
		return nil, noExpiry, fmt.Errorf("oidc: get keys failed %v", err)
	}
	defer response.Body.Close()

	// Read the body before checking the status so it can be included in the
	// error message for non-200 responses.
	payload, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return nil, noExpiry, fmt.Errorf("oidc: read response body: %v", err)
	}
	if response.StatusCode != http.StatusOK {
		return nil, noExpiry, fmt.Errorf("oidc: get keys failed: %s %s", response.Status, payload)
	}

	var jwks jose.JSONWebKeySet
	if err = json.Unmarshal(payload, &jwks); err != nil {
		return nil, noExpiry, fmt.Errorf("oidc: failed to decode keys: %v %s", err, payload)
	}

	// If the server doesn't provide cache control headers, assume the
	// keys expire immediately.
	expiresAt := r.now()
	if _, cacheUntil, ccErr := cachecontrol.CachableResponse(request, response, cachecontrol.Options{}); ccErr == nil && cacheUntil.After(expiresAt) {
		expiresAt = cacheUntil
	}
	return jwks.Keys, expiresAt, nil
}