diff --git a/internal/net/grpc/client.go b/internal/net/grpc/client.go index 78fc51448b..fd154566cb 100644 --- a/internal/net/grpc/client.go +++ b/internal/net/grpc/client.go @@ -90,7 +90,7 @@ type Client interface { GetCallOption() []CallOption GetBackoff() backoff.Backoff SetDisableResolveDNSAddr(addr string, disabled bool) - ConnectedAddrs() []string + ConnectedAddrs(context.Context) []string Close(ctx context.Context) error } @@ -249,7 +249,7 @@ func (g *gRPCClient) StartConnectionMonitor(ctx context.Context) (<-chan error, return ctx.Err() case <-prTick.C: if g.enablePoolRebalance { - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { // if addr or pool is nil or empty the registration of conns is invalid let's disconnect them if addr == "" || p == nil { disconnectTargets = append(disconnectTargets, addr) @@ -286,7 +286,7 @@ func (g *gRPCClient) StartConnectionMonitor(ctx context.Context) (<-chan error, }) } case <-hcTick.C: - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { // if addr or pool is nil or empty the registration of conns is invalid let's disconnect them if addr == "" || p == nil { disconnectTargets = append(disconnectTargets, addr) @@ -415,7 +415,7 @@ func (g *gRPCClient) Range( if g.conns.Len() == 0 { return errors.ErrGRPCClientConnNotFound("*") } - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { ssctx, sspan := trace.StartSpan(sctx, apiName+"/Client.Range/"+addr) defer func() { if sspan != nil { @@ -478,7 +478,7 @@ func (g *gRPCClient) RangeConcurrent( if g.conns.Len() == 0 { return errors.ErrGRPCClientConnNotFound("*") } - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { eg.Go(safety.RecoverFunc(func() (err error) { ssctx, sspan := trace.StartSpan(egctx, apiName+"/Client.RangeConcurrent/"+addr) defer func() { @@ -565,7 +565,7 @@ func (g *gRPCClient) OrderedRange( return nil default: p, ok := g.conns.Load(addr) - if !ok || p == nil { + if !ok || p == nil || !p.IsHealthy(sctx) { g.crl.Store(addr, true) log.Warnf("gRPCClient.OrderedRange operation failed, gRPC connection pool for %s is invalid,\terror: %v", addr, errors.ErrGRPCClientConnNotFound(addr)) continue @@ -634,7 +634,7 @@ func (g *gRPCClient) OrderedRangeConcurrent( addr := order eg.Go(safety.RecoverFunc(func() (err error) { p, ok := g.conns.Load(addr) - if !ok || p == nil { + if !ok || p == nil || !p.IsHealthy(sctx) { g.crl.Store(addr, true) log.Warnf("gRPCClient.OrderedRangeConcurrent operation failed, gRPC connection pool for %s is invalid,\terror: %v", addr, errors.ErrGRPCClientConnNotFound(addr)) return nil @@ -701,7 +701,7 @@ func (g *gRPCClient) RoundRobin( } do := func() (data any, err error) { - cerr := g.rangeConns(func(addr string, p pool.Conn) bool { + cerr := g.rangeConns(ctx, func(addr string, p pool.Conn) bool { select { case <-ctx.Done(): err = ctx.Err() @@ -879,14 +879,14 @@ func (g *gRPCClient) connectWithBackoff( errors.Is(err, context.DeadlineExceeded) { return nil, false, err } - return nil, err != nil, err + return nil, p.IsHealthy(ctx), err } status.Log(st.Code(), err) switch st.Code() { case codes.Internal, codes.Unavailable, codes.ResourceExhausted: - return nil, err != nil, err + return nil, p.IsHealthy(ctx), err } return nil, false, err } @@ -1066,7 +1066,7 @@ func (g *gRPCClient) Disconnect(ctx context.Context, addr string) error { atomic.AddUint64(&g.clientCount, ^uint64(0)) if p != nil { log.Debugf("gRPC client connection pool addr = %s will disconnect soon...", addr) - return nil, p.Disconnect() + return nil, p.Disconnect(ctx) } return nil, nil }) @@ -1085,10 +1085,10 @@ func (g *gRPCClient) Disconnect(ctx context.Context, addr string) error { return nil } -func (g *gRPCClient) ConnectedAddrs() (addrs []string) { +func (g *gRPCClient) ConnectedAddrs(ctx context.Context) (addrs []string) { addrs = make([]string, 0, g.conns.Len()) - err := g.rangeConns(func(addr string, p pool.Conn) bool { - if p != nil && p.IsHealthy(context.Background()) { + err := g.rangeConns(ctx, func(addr string, p pool.Conn) bool { + if p != nil && p.IsHealthy(ctx) { addrs = append(addrs, addr) } return true @@ -1104,18 +1104,34 @@ func (g *gRPCClient) Close(ctx context.Context) (err error) { g.stopMonitor() } g.conns.Range(func(addr string, p pool.Conn) bool { - derr := g.Disconnect(ctx, addr) - if derr != nil && !errors.Is(derr, errors.ErrGRPCClientConnNotFound(addr)) { - err = errors.Join(err, derr) + select { + case <-ctx.Done(): + return false + default: + derr := g.Disconnect(ctx, addr) + if derr != nil && !errors.Is(derr, errors.ErrGRPCClientConnNotFound(addr)) { + err = errors.Join(err, derr) + } + return true } - return true }) return err } -func (g *gRPCClient) rangeConns(fn func(addr string, p pool.Conn) bool) error { +func (g *gRPCClient) rangeConns(ctx context.Context, fn func(addr string, p pool.Conn) bool) error { var cnt int g.conns.Range(func(addr string, p pool.Conn) bool { + if p == nil || !p.IsHealthy(ctx) { + pc, err := p.Connect(ctx) + if pc == nil || err != nil || !pc.IsHealthy(ctx) { + if pc != nil { + pc.Disconnect(ctx) + } + log.Debugf("Unhealthy connection detected for %s during gRPC Connection Range over Loop:\t%s", addr, p.String()) + return true + } + p = pc + } cnt++ return fn(addr, p) }) diff --git a/internal/net/grpc/client_test.go b/internal/net/grpc/client_test.go index 0d073ce858..935ae07681 100644 --- a/internal/net/grpc/client_test.go +++ b/internal/net/grpc/client_test.go @@ -3107,7 +3107,7 @@ package grpc // stopMonitor: test.fields.stopMonitor, // } // -// gotAddrs := g.ConnectedAddrs() +// gotAddrs := g.ConnectedAddrs(context.Background) // if err := checkFunc(test.want, gotAddrs); err != nil { // tt.Errorf("error = %v", err) // } @@ -3303,6 +3303,7 @@ package grpc // // func Test_gRPCClient_rangeConns(t *testing.T) { // type args struct { +// ctx context.Context // fn func(addr string, p pool.Conn) bool // } // type fields struct { diff --git a/internal/net/grpc/pool/pool.go b/internal/net/grpc/pool/pool.go index 1af0edb867..3d9583e4da 100644 --- a/internal/net/grpc/pool/pool.go +++ b/internal/net/grpc/pool/pool.go @@ -34,6 +34,7 @@ import ( "github.com/vdaas/vald/internal/strings" "github.com/vdaas/vald/internal/sync" "github.com/vdaas/vald/internal/sync/errgroup" + "github.com/vdaas/vald/internal/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" ) @@ -45,9 +46,9 @@ type ( type Conn interface { Connect(context.Context) (Conn, error) - Disconnect() error + Disconnect(context.Context) error Do(ctx context.Context, f func(*ClientConn) error) error - Get(ctx context.Context) (conn *ClientConn, ok bool) + Get(context.Context) (conn *ClientConn, ok bool) IsHealthy(context.Context) bool IsIPConn() bool Len() uint64 @@ -72,6 +73,7 @@ type pool struct { current atomic.Uint64 bo backoff.Backoff eg errgroup.Group + group singleflight.Group[Conn] dopts []DialOption dialTimeout time.Duration roccd time.Duration // reconnection old connection closing duration @@ -94,6 +96,8 @@ func New(ctx context.Context, opts ...Option) (c Conn, err error) { p.init(true) p.closing.Store(false) + p.group = singleflight.New[Conn]() + var ( isIPv4, isIPv6 bool port uint16 @@ -222,11 +226,11 @@ func (p *pool) store(idx int, pc *poolConn) { return } p.init(false) - p.pmu.RLock() + p.pmu.Lock() if p.pool != nil && p.Size() > uint64(idx) && len(p.pool) > idx { p.pool[idx].Store(pc) } - p.pmu.RUnlock() + p.pmu.Unlock() } func (p *pool) loop( @@ -335,7 +339,6 @@ func (p *pool) Connect(ctx context.Context) (c Conn, err error) { if p == nil || p.closing.Load() { return p, nil } - p.init(false) if p.isIP || !p.resolveDNS { @@ -437,8 +440,7 @@ func (p *pool) singleTargetConnect(ctx context.Context) (c Conn, err error) { return p, nil } -func (p *pool) Disconnect() (err error) { - ctx := context.Background() +func (p *pool) Disconnect(ctx context.Context) (err error) { p.closing.Store(true) defer p.closing.Store(false) emap := make(map[string]error, p.len()) @@ -572,10 +574,11 @@ func (p *pool) Do(ctx context.Context, f func(conn *ClientConn) error) (err erro if p == nil { return errors.ErrGRPCClientConnNotFound("*") } - idx, conn, ok := p.getHealthyConn(ctx, 0, p.Len()) - if !ok || conn == nil { + idx, pc, ok := p.getHealthyConn(ctx, 0, p.Len()) + if !ok || pc == nil || pc.conn == nil { return errors.ErrGRPCClientConnNotFound(p.addr) } + conn := pc.conn err = f(conn) if errors.Is(err, grpc.ErrClientConnClosing) { if conn != nil { @@ -583,29 +586,29 @@ func (p *pool) Do(ctx context.Context, f func(conn *ClientConn) error) (err erro log.Warnf("Failed to close connection: %v", cerr) } } - conn, err = p.dial(ctx, p.addr) - if err == nil && conn != nil && isHealthy(ctx, conn) { - p.store(idx, &poolConn{ - conn: conn, - addr: p.addr, - }) - if newErr := f(conn); newErr != nil { + rerr := p.refreshConn(ctx, idx, pc, p.addr) + if rerr == nil { + if newErr := f(p.load(idx).conn); newErr != nil { return errors.Join(err, newErr) } return nil } + err = errors.Join(err, rerr) } return err } func (p *pool) Get(ctx context.Context) (conn *ClientConn, ok bool) { - _, conn, ok = p.getHealthyConn(ctx, 0, p.Len()) - return conn, ok + _, pc, ok := p.getHealthyConn(ctx, 0, p.Len()) + if ok && pc != nil { + return pc.conn, true + } + return nil, false } func (p *pool) getHealthyConn( ctx context.Context, cnt, retry uint64, -) (idx int, conn *ClientConn, ok bool) { +) (idx int, pc *poolConn, ok bool) { if p == nil || p.closing.Load() { return 0, nil, false } @@ -618,7 +621,7 @@ func (p *pool) getHealthyConn( if retry <= 0 || retry > math.MaxUint64-pl || pl <= 0 { if p.isIP { log.Warnf("failed to find gRPC IP connection pool for %s.\tlen(pool): %d,\tretried: %d,\tseems IP %s is unhealthy will going to disconnect...", p.addr, pl, cnt, p.addr) - if err := p.Disconnect(); err != nil { + if err := p.Disconnect(ctx); err != nil { log.Debugf("failed to disconnect gRPC IP direct connection for %s,\terr: %v", p.addr, err) } return 0, nil, false @@ -626,16 +629,12 @@ func (p *pool) getHealthyConn( if pl > 0 { idx = int(p.current.Add(1) % pl) } - if pc := p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { - return idx, pc.conn, true + if pc = p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { + return idx, pc, true } - conn, err := p.dial(ctx, p.addr) - if err == nil && conn != nil && isHealthy(ctx, conn) { - p.store(idx, &poolConn{ - conn: conn, - addr: p.addr, - }) - return idx, conn, true + err := p.refreshConn(ctx, idx, pc, p.addr) + if err == nil { + return idx, p.load(idx), true } log.Warnf("failed to find gRPC connection pool for %s.\tlen(pool): %d,\tretried: %d,\terror: %v", p.addr, pl, cnt, err) return idx, nil, false @@ -643,8 +642,8 @@ func (p *pool) getHealthyConn( if pl > 0 { idx = int(p.current.Add(1) % pl) - if pc := p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { - return idx, pc.conn, true + if pc = p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { + return idx, pc, true } } retry-- @@ -757,8 +756,8 @@ func (p *pool) String() (str string) { func (pc *poolConn) Close(ctx context.Context, delay time.Duration) error { tdelay := delay / 10 - if tdelay < time.Millisecond*200 { - tdelay = time.Millisecond * 200 + if tdelay < time.Millisecond*5 { + tdelay = time.Millisecond * 5 } else if tdelay > time.Minute { tdelay = time.Second * 5 } diff --git a/internal/net/grpc/pool/pool_test.go b/internal/net/grpc/pool/pool_test.go index 1cc4cccd43..b970fe9485 100644 --- a/internal/net/grpc/pool/pool_test.go +++ b/internal/net/grpc/pool/pool_test.go @@ -2344,7 +2344,7 @@ package pool // reconnectHash: test.fields.reconnectHash, // } // -// err := p.Disconnect() +// err := p.Disconnect(context.Background) // if err := checkFunc(test.want, err); err != nil { // tt.Errorf("error = %v", err) // } diff --git a/internal/test/mock/grpc/grpc_client_mock.go b/internal/test/mock/grpc/grpc_client_mock.go index 078a409d9d..89cf2f75d1 100644 --- a/internal/test/mock/grpc/grpc_client_mock.go +++ b/internal/test/mock/grpc/grpc_client_mock.go @@ -51,7 +51,7 @@ func (gc *GRPCClientMock) OrderedRangeConcurrent( } // ConnectedAddrs calls the ConnectedAddrsFunc object. -func (gc *GRPCClientMock) ConnectedAddrs() []string { +func (gc *GRPCClientMock) ConnectedAddrs(_ context.Context) []string { return gc.ConnectedAddrsFunc() } diff --git a/internal/test/mock/grpc_testify_mock.go b/internal/test/mock/grpc_testify_mock.go index 20e215d29f..ce55e6923f 100644 --- a/internal/test/mock/grpc_testify_mock.go +++ b/internal/test/mock/grpc_testify_mock.go @@ -199,7 +199,7 @@ func (c *ClientInternal) GetBackoff() backoff.Backoff { return v } -func (c *ClientInternal) ConnectedAddrs() []string { +func (c *ClientInternal) ConnectedAddrs(ctx context.Context) []string { args := c.Called() v, ok := args.Get(0).([]string) if !ok { diff --git a/pkg/gateway/lb/handler/grpc/aggregation.go b/pkg/gateway/lb/handler/grpc/aggregation.go index 24ad0925cd..0a5a04a235 100644 --- a/pkg/gateway/lb/handler/grpc/aggregation.go +++ b/pkg/gateway/lb/handler/grpc/aggregation.go @@ -101,6 +101,7 @@ func (s *server) aggregationSearch( target + " canceled: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil case errors.Is(err, context.DeadlineExceeded), errors.Is(err, errors.ErrRPCCallFailed(target, context.DeadlineExceeded)): @@ -112,6 +113,7 @@ func (s *server) aggregationSearch( target + " deadline_exceeded: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil default: st, msg, err := status.ParseError(err, codes.Unknown, "failed to parse search gRPC error response", @@ -168,6 +170,7 @@ func (s *server) aggregationSearch( target + " canceled: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil case errors.Is(err, context.DeadlineExceeded), errors.Is(err, errors.ErrRPCCallFailed(target, context.DeadlineExceeded)): @@ -179,6 +182,7 @@ func (s *server) aggregationSearch( target + " deadline_exceeded: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil default: st, msg, err := status.ParseError(err, codes.Unknown, "failed to parse search gRPC error response", diff --git a/pkg/gateway/mirror/service/mirror.go b/pkg/gateway/mirror/service/mirror.go index dbd76572d4..959c555248 100644 --- a/pkg/gateway/mirror/service/mirror.go +++ b/pkg/gateway/mirror/service/mirror.go @@ -160,7 +160,7 @@ func (m *mirr) Start(ctx context.Context) <-chan error { // skipcq: GO-R1005 } } } - log.Debugf("[mirror]: connected mirror gateway targets: %v", m.gateway.GRPCClient().ConnectedAddrs()) + log.Debugf("[mirror]: connected mirror gateway targets: %v", m.gateway.GRPCClient().ConnectedAddrs(ctx)) } } })