@@ -283,15 +283,57 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
283
283
return cn , nil
284
284
}
285
285
286
- func (c * baseClient ) reAuth (ctx context.Context , cn * Conn , credentials auth.Credentials ) error {
287
- var err error
288
- username , password := credentials .BasicAuth ()
289
- if username != "" {
290
- err = cn .AuthACL (ctx , username , password ).Err ()
291
- } else {
292
- err = cn .Auth (ctx , password ).Err ()
286
+ func (c * baseClient ) newReAuthCredentialsListener (ctx context.Context , conn * Conn ) auth.CredentialsListener {
287
+ return auth .NewReAuthCredentialsListener (
288
+ c .reAuthConnection (c .context (ctx ), conn ),
289
+ c .onAuthenticationErr (c .context (ctx ), conn ),
290
+ )
291
+ }
292
+
293
+ func (c * baseClient ) reAuthConnection (ctx context.Context , cn * Conn ) func (credentials auth.Credentials ) error {
294
+ return func (credentials auth.Credentials ) error {
295
+ var err error
296
+ username , password := credentials .BasicAuth ()
297
+ if username != "" {
298
+ err = cn .AuthACL (ctx , username , password ).Err ()
299
+ } else {
300
+ err = cn .Auth (ctx , password ).Err ()
301
+ }
302
+ return err
303
+ }
304
+ }
305
+ func (c * baseClient ) onAuthenticationErr (ctx context.Context , cn * Conn ) func (err error ) {
306
+ return func (err error ) {
307
+ // since the connection pool of the *Conn will actually return us the underlying pool.Conn,
308
+ // we can get it from the *Conn and remove it from the clients pool.
309
+ if err != nil {
310
+ if isBadConn (err , false , c .opt .Addr ) {
311
+ poolCn , _ := cn .connPool .Get (ctx )
312
+ c .connPool .Remove (ctx , poolCn , err )
313
+ }
314
+ }
315
+ }
316
+ }
317
+
318
+ func (c * baseClient ) wrappedOnClose (newOnClose func () error ) func () error {
319
+ onClose := c .onClose
320
+ return func () error {
321
+ var firstErr error
322
+ err := newOnClose ()
323
+ // Even if we have an error we would like to execute the onClose hook
324
+ // if it exists. We will return the first error that occurred.
325
+ // This is to keep error handling consistent with the rest of the code.
326
+ if err != nil {
327
+ firstErr = err
328
+ }
329
+ if onClose != nil {
330
+ err = onClose ()
331
+ if err != nil && firstErr == nil {
332
+ firstErr = err
333
+ }
334
+ }
335
+ return firstErr
293
336
}
294
- return err
295
337
}
296
338
297
339
func (c * baseClient ) initConn (ctx context.Context , cn * pool.Conn ) error {
@@ -312,7 +354,15 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
312
354
313
355
var authenticated bool
314
356
username , password := c .opt .Username , c .opt .Password
315
- if c .opt .CredentialsProviderContext != nil {
357
+ if c .opt .StreamingCredentialsProvider != nil {
358
+ credentials , cancelCredentialsProvider , err := c .opt .StreamingCredentialsProvider .
359
+ Subscribe (c .newReAuthCredentialsListener (ctx , conn ))
360
+ if err != nil {
361
+ return err
362
+ }
363
+ c .onClose = c .wrappedOnClose (cancelCredentialsProvider )
364
+ username , password = credentials .BasicAuth ()
365
+ } else if c .opt .CredentialsProviderContext != nil {
316
366
if username , password , err = c .opt .CredentialsProviderContext (ctx ); err != nil {
317
367
return err
318
368
}
@@ -336,7 +386,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
336
386
}
337
387
338
388
if ! authenticated && password != "" {
339
- err = c .reAuth (ctx , conn , auth .NewBasicCredentials (username , password ))
389
+ err = c .reAuthConnection (ctx , conn )( auth .NewBasicCredentials (username , password ))
340
390
if err != nil {
341
391
return err
342
392
}
0 commit comments