Skip to content

Commit 40a89c5

Browse files
committed
Initial re authentication implementation
Introduces the StreamingCredentialsProvider as the CredentialsProvider with the highest priority. TODO: needs to be tested
1 parent 847f1f9 commit 40a89c5

File tree

5 files changed

+111
-15
lines changed

5 files changed

+111
-15
lines changed

auth/auth.go

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ type StreamingCredentialsProvider interface {
99
// Subscribe subscribes to the credentials provider for updates.
1010
// It returns the current credentials, a cancel function to unsubscribe from the provider,
1111
// and an error if any.
12+
// TODO(ndyakov): Should we add context to the Subscribe method?
1213
Subscribe(listener CredentialsListener) (Credentials, CancelProviderFunc, error)
1314
}
1415

auth/reauth_credentials_listener.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package auth
2+
3+
// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
4+
// It is used to re-authenticate the credentials when they are updated.
5+
// It contains:
6+
// - reAuth: a function that takes the new credentials and returns an error if any.
7+
// - onErr: a function that takes an error and handles it.
8+
type ReAuthCredentialsListener struct {
9+
reAuth func(credentials Credentials) error
10+
onErr func(err error)
11+
}
12+
13+
// OnNext is called when the credentials are updated.
14+
// It calls the reAuth function with the new credentials.
15+
// If the reAuth function returns an error, it calls the onErr function with the error.
16+
func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) {
17+
if c.reAuth != nil {
18+
err := c.reAuth(credentials)
19+
if err != nil {
20+
if c.onErr != nil {
21+
c.onErr(err)
22+
}
23+
}
24+
}
25+
}
26+
27+
// OnError is called when an error occurs.
28+
// It can be called from both the credentials provider and the reAuth function.
29+
func (c *ReAuthCredentialsListener) OnError(err error) {
30+
if c.onErr != nil {
31+
c.onErr(err)
32+
}
33+
}
34+
35+
// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener.
36+
// Implements the auth.CredentialsListener interface.
37+
func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener {
38+
return &ReAuthCredentialsListener{
39+
reAuth: reAuth,
40+
onErr: onErr,
41+
}
42+
}
43+
44+
// Ensure ReAuthCredentialsListener implements the CredentialsListener interface.
45+
var _ CredentialsListener = (*ReAuthCredentialsListener)(nil)

internal_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,10 @@ func TestRingShardsCleanup(t *testing.T) {
212212
},
213213
NewClient: func(opt *Options) *Client {
214214
c := NewClient(opt)
215-
c.baseClient.onClose = func() error {
215+
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
216216
closeCounter.increment(opt.Addr)
217217
return nil
218-
}
218+
})
219219
return c
220220
},
221221
})
@@ -261,10 +261,10 @@ func TestRingShardsCleanup(t *testing.T) {
261261
}
262262
createCounter.increment(opt.Addr)
263263
c := NewClient(opt)
264-
c.baseClient.onClose = func() error {
264+
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
265265
closeCounter.increment(opt.Addr)
266266
return nil
267-
}
267+
})
268268
return c
269269
},
270270
})

redis.go

+60-10
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,57 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
283283
return cn, nil
284284
}
285285

286-
func (c *baseClient) reAuth(ctx context.Context, cn *Conn, credentials auth.Credentials) error {
287-
var err error
288-
username, password := credentials.BasicAuth()
289-
if username != "" {
290-
err = cn.AuthACL(ctx, username, password).Err()
291-
} else {
292-
err = cn.Auth(ctx, password).Err()
286+
func (c *baseClient) newReAuthCredentialsListener(ctx context.Context, conn *Conn) auth.CredentialsListener {
287+
return auth.NewReAuthCredentialsListener(
288+
c.reAuthConnection(c.context(ctx), conn),
289+
c.onAuthenticationErr(c.context(ctx), conn),
290+
)
291+
}
292+
293+
func (c *baseClient) reAuthConnection(ctx context.Context, cn *Conn) func(credentials auth.Credentials) error {
294+
return func(credentials auth.Credentials) error {
295+
var err error
296+
username, password := credentials.BasicAuth()
297+
if username != "" {
298+
err = cn.AuthACL(ctx, username, password).Err()
299+
} else {
300+
err = cn.Auth(ctx, password).Err()
301+
}
302+
return err
303+
}
304+
}
305+
func (c *baseClient) onAuthenticationErr(ctx context.Context, cn *Conn) func(err error) {
306+
return func(err error) {
307+
// since the connection pool of the *Conn will actually return us the underlying pool.Conn,
308+
// we can get it from the *Conn and remove it from the clients pool.
309+
if err != nil {
310+
if isBadConn(err, false, c.opt.Addr) {
311+
poolCn, _ := cn.connPool.Get(ctx)
312+
c.connPool.Remove(ctx, poolCn, err)
313+
}
314+
}
315+
}
316+
}
317+
318+
func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
319+
onClose := c.onClose
320+
return func() error {
321+
var firstErr error
322+
err := newOnClose()
323+
// Even if we have an error we would like to execute the onClose hook
324+
// if it exists. We will return the first error that occurred.
325+
// This is to keep error handling consistent with the rest of the code.
326+
if err != nil {
327+
firstErr = err
328+
}
329+
if onClose != nil {
330+
err = onClose()
331+
if err != nil && firstErr == nil {
332+
firstErr = err
333+
}
334+
}
335+
return firstErr
293336
}
294-
return err
295337
}
296338

297339
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
@@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
312354

313355
var authenticated bool
314356
username, password := c.opt.Username, c.opt.Password
315-
if c.opt.CredentialsProviderContext != nil {
357+
if c.opt.StreamingCredentialsProvider != nil {
358+
credentials, cancelCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
359+
Subscribe(c.newReAuthCredentialsListener(ctx, conn))
360+
if err != nil {
361+
return err
362+
}
363+
c.onClose = c.wrappedOnClose(cancelCredentialsProvider)
364+
username, password = credentials.BasicAuth()
365+
} else if c.opt.CredentialsProviderContext != nil {
316366
if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil {
317367
return err
318368
}
@@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
336386
}
337387

338388
if !authenticated && password != "" {
339-
err = c.reAuth(ctx, conn, auth.NewBasicCredentials(username, password))
389+
err = c.reAuthConnection(ctx, conn)(auth.NewBasicCredentials(username, password))
340390
if err != nil {
341391
return err
342392
}

sentinel.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
257257

258258
connPool = newConnPool(opt, rdb.dialHook)
259259
rdb.connPool = connPool
260-
rdb.onClose = failover.Close
260+
rdb.onClose = rdb.wrappedOnClose(failover.Close)
261261

262262
failover.mu.Lock()
263263
failover.onFailover = func(ctx context.Context, addr string) {

0 commit comments

Comments
 (0)