From 770b234e21f1df7e40816b73ce5a869e5fd26b06 Mon Sep 17 00:00:00 2001 From: kpango Date: Thu, 9 Jan 2025 18:39:07 +0900 Subject: [PATCH] [BUGFIX] add Health Check for Range over gRPC Connection Loop Signed-off-by: kpango --- internal/net/grpc/client.go | 54 +++++++++++++-------- internal/net/grpc/client_test.go | 3 +- internal/net/grpc/pool/pool.go | 13 +++-- internal/net/grpc/pool/pool_test.go | 2 +- internal/test/mock/grpc/grpc_client_mock.go | 2 +- internal/test/mock/grpc_testify_mock.go | 2 +- pkg/gateway/lb/handler/grpc/aggregation.go | 4 ++ pkg/gateway/mirror/service/mirror.go | 2 +- 8 files changed, 51 insertions(+), 31 deletions(-) diff --git a/internal/net/grpc/client.go b/internal/net/grpc/client.go index 78fc51448b..fd154566cb 100644 --- a/internal/net/grpc/client.go +++ b/internal/net/grpc/client.go @@ -90,7 +90,7 @@ type Client interface { GetCallOption() []CallOption GetBackoff() backoff.Backoff SetDisableResolveDNSAddr(addr string, disabled bool) - ConnectedAddrs() []string + ConnectedAddrs(context.Context) []string Close(ctx context.Context) error } @@ -249,7 +249,7 @@ func (g *gRPCClient) StartConnectionMonitor(ctx context.Context) (<-chan error, return ctx.Err() case <-prTick.C: if g.enablePoolRebalance { - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { // if addr or pool is nil or empty the registration of conns is invalid let's disconnect them if addr == "" || p == nil { disconnectTargets = append(disconnectTargets, addr) @@ -286,7 +286,7 @@ func (g *gRPCClient) StartConnectionMonitor(ctx context.Context) (<-chan error, }) } case <-hcTick.C: - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { // if addr or pool is nil or empty the registration of conns is invalid let's disconnect them if addr == "" || p == nil { disconnectTargets = append(disconnectTargets, addr) @@ -415,7 +415,7 @@ func (g 
*gRPCClient) Range( if g.conns.Len() == 0 { return errors.ErrGRPCClientConnNotFound("*") } - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { ssctx, sspan := trace.StartSpan(sctx, apiName+"/Client.Range/"+addr) defer func() { if sspan != nil { @@ -478,7 +478,7 @@ func (g *gRPCClient) RangeConcurrent( if g.conns.Len() == 0 { return errors.ErrGRPCClientConnNotFound("*") } - err = g.rangeConns(func(addr string, p pool.Conn) bool { + err = g.rangeConns(ctx, func(addr string, p pool.Conn) bool { eg.Go(safety.RecoverFunc(func() (err error) { ssctx, sspan := trace.StartSpan(egctx, apiName+"/Client.RangeConcurrent/"+addr) defer func() { @@ -565,7 +565,7 @@ func (g *gRPCClient) OrderedRange( return nil default: p, ok := g.conns.Load(addr) - if !ok || p == nil { + if !ok || p == nil || !p.IsHealthy(sctx) { g.crl.Store(addr, true) log.Warnf("gRPCClient.OrderedRange operation failed, gRPC connection pool for %s is invalid,\terror: %v", addr, errors.ErrGRPCClientConnNotFound(addr)) continue @@ -634,7 +634,7 @@ func (g *gRPCClient) OrderedRangeConcurrent( addr := order eg.Go(safety.RecoverFunc(func() (err error) { p, ok := g.conns.Load(addr) - if !ok || p == nil { + if !ok || p == nil || !p.IsHealthy(sctx) { g.crl.Store(addr, true) log.Warnf("gRPCClient.OrderedRangeConcurrent operation failed, gRPC connection pool for %s is invalid,\terror: %v", addr, errors.ErrGRPCClientConnNotFound(addr)) return nil @@ -701,7 +701,7 @@ func (g *gRPCClient) RoundRobin( } do := func() (data any, err error) { - cerr := g.rangeConns(func(addr string, p pool.Conn) bool { + cerr := g.rangeConns(ctx, func(addr string, p pool.Conn) bool { select { case <-ctx.Done(): err = ctx.Err() @@ -879,14 +879,14 @@ func (g *gRPCClient) connectWithBackoff( errors.Is(err, context.DeadlineExceeded) { return nil, false, err } - return nil, err != nil, err + return nil, p.IsHealthy(ctx), err } status.Log(st.Code(), err) switch st.Code() { case 
codes.Internal, codes.Unavailable, codes.ResourceExhausted: - return nil, err != nil, err + return nil, p.IsHealthy(ctx), err } return nil, false, err } @@ -1066,7 +1066,7 @@ func (g *gRPCClient) Disconnect(ctx context.Context, addr string) error { atomic.AddUint64(&g.clientCount, ^uint64(0)) if p != nil { log.Debugf("gRPC client connection pool addr = %s will disconnect soon...", addr) - return nil, p.Disconnect() + return nil, p.Disconnect(ctx) } return nil, nil }) @@ -1085,10 +1085,10 @@ func (g *gRPCClient) Disconnect(ctx context.Context, addr string) error { return nil } -func (g *gRPCClient) ConnectedAddrs() (addrs []string) { +func (g *gRPCClient) ConnectedAddrs(ctx context.Context) (addrs []string) { addrs = make([]string, 0, g.conns.Len()) - err := g.rangeConns(func(addr string, p pool.Conn) bool { - if p != nil && p.IsHealthy(context.Background()) { + err := g.rangeConns(ctx, func(addr string, p pool.Conn) bool { + if p != nil && p.IsHealthy(ctx) { addrs = append(addrs, addr) } return true @@ -1104,18 +1104,34 @@ func (g *gRPCClient) Close(ctx context.Context) (err error) { g.stopMonitor() } g.conns.Range(func(addr string, p pool.Conn) bool { - derr := g.Disconnect(ctx, addr) - if derr != nil && !errors.Is(derr, errors.ErrGRPCClientConnNotFound(addr)) { - err = errors.Join(err, derr) + select { + case <-ctx.Done(): + return false + default: + derr := g.Disconnect(ctx, addr) + if derr != nil && !errors.Is(derr, errors.ErrGRPCClientConnNotFound(addr)) { + err = errors.Join(err, derr) + } + return true } - return true }) return err } -func (g *gRPCClient) rangeConns(fn func(addr string, p pool.Conn) bool) error { +func (g *gRPCClient) rangeConns(ctx context.Context, fn func(addr string, p pool.Conn) bool) error { var cnt int g.conns.Range(func(addr string, p pool.Conn) bool { + if p != nil && !p.IsHealthy(ctx) { /* a nil pool is handed to fn unchanged: calling Connect or String on a nil pool.Conn would panic */ + pc, err := p.Connect(ctx) + if pc == nil || err != nil || !pc.IsHealthy(ctx) { + if pc != nil { + pc.Disconnect(ctx) + } + log.Debugf("Unhealthy 
connection detected for %s during gRPC Connection Range over Loop:\t%s", addr, p.String()) + return true + } + p = pc + } cnt++ return fn(addr, p) }) diff --git a/internal/net/grpc/client_test.go b/internal/net/grpc/client_test.go index 0d073ce858..935ae07681 100644 --- a/internal/net/grpc/client_test.go +++ b/internal/net/grpc/client_test.go @@ -3107,7 +3107,7 @@ package grpc // stopMonitor: test.fields.stopMonitor, // } // -// gotAddrs := g.ConnectedAddrs() +// gotAddrs := g.ConnectedAddrs(context.Background()) // if err := checkFunc(test.want, gotAddrs); err != nil { // tt.Errorf("error = %v", err) // } @@ -3303,6 +3303,7 @@ package grpc // // func Test_gRPCClient_rangeConns(t *testing.T) { // type args struct { +// ctx context.Context // fn func(addr string, p pool.Conn) bool // } // type fields struct { diff --git a/internal/net/grpc/pool/pool.go b/internal/net/grpc/pool/pool.go index 1af0edb867..39e77a2c28 100644 --- a/internal/net/grpc/pool/pool.go +++ b/internal/net/grpc/pool/pool.go @@ -45,9 +45,9 @@ type ( type Conn interface { Connect(context.Context) (Conn, error) - Disconnect() error + Disconnect(context.Context) error Do(ctx context.Context, f func(*ClientConn) error) error - Get(ctx context.Context) (conn *ClientConn, ok bool) + Get(context.Context) (conn *ClientConn, ok bool) IsHealthy(context.Context) bool IsIPConn() bool Len() uint64 @@ -437,8 +437,7 @@ func (p *pool) singleTargetConnect(ctx context.Context) (c Conn, err error) { return p, nil } -func (p *pool) Disconnect() (err error) { - ctx := context.Background() +func (p *pool) Disconnect(ctx context.Context) (err error) { p.closing.Store(true) defer p.closing.Store(false) emap := make(map[string]error, p.len()) @@ -618,7 +617,7 @@ func (p *pool) getHealthyConn( if retry <= 0 || retry > math.MaxUint64-pl || pl <= 0 { if p.isIP { log.Warnf("failed to find gRPC IP connection pool for %s.\tlen(pool): %d,\tretried: %d,\tseems IP %s is unhealthy will going to disconnect...", p.addr, pl, cnt, p.addr) 
- if err := p.Disconnect(); err != nil { + if err := p.Disconnect(ctx); err != nil { log.Debugf("failed to disconnect gRPC IP direct connection for %s,\terr: %v", p.addr, err) } return 0, nil, false @@ -757,8 +756,8 @@ func (p *pool) String() (str string) { func (pc *poolConn) Close(ctx context.Context, delay time.Duration) error { tdelay := delay / 10 - if tdelay < time.Millisecond*200 { - tdelay = time.Millisecond * 200 + if tdelay < time.Millisecond*5 { + tdelay = time.Millisecond * 5 } else if tdelay > time.Minute { tdelay = time.Second * 5 } diff --git a/internal/net/grpc/pool/pool_test.go b/internal/net/grpc/pool/pool_test.go index 1cc4cccd43..b970fe9485 100644 --- a/internal/net/grpc/pool/pool_test.go +++ b/internal/net/grpc/pool/pool_test.go @@ -2344,7 +2344,7 @@ package pool // reconnectHash: test.fields.reconnectHash, // } // -// err := p.Disconnect() +// err := p.Disconnect(context.Background()) // if err := checkFunc(test.want, err); err != nil { // tt.Errorf("error = %v", err) // } diff --git a/internal/test/mock/grpc/grpc_client_mock.go b/internal/test/mock/grpc/grpc_client_mock.go index 078a409d9d..89cf2f75d1 100644 --- a/internal/test/mock/grpc/grpc_client_mock.go +++ b/internal/test/mock/grpc/grpc_client_mock.go @@ -51,7 +51,7 @@ func (gc *GRPCClientMock) OrderedRangeConcurrent( } // ConnectedAddrs calls the ConnectedAddrsFunc object. 
-func (gc *GRPCClientMock) ConnectedAddrs() []string { +func (gc *GRPCClientMock) ConnectedAddrs(_ context.Context) []string { return gc.ConnectedAddrsFunc() } diff --git a/internal/test/mock/grpc_testify_mock.go b/internal/test/mock/grpc_testify_mock.go index 20e215d29f..ce55e6923f 100644 --- a/internal/test/mock/grpc_testify_mock.go +++ b/internal/test/mock/grpc_testify_mock.go @@ -199,7 +199,7 @@ func (c *ClientInternal) GetBackoff() backoff.Backoff { return v } -func (c *ClientInternal) ConnectedAddrs() []string { +func (c *ClientInternal) ConnectedAddrs(ctx context.Context) []string { args := c.Called() v, ok := args.Get(0).([]string) if !ok { diff --git a/pkg/gateway/lb/handler/grpc/aggregation.go b/pkg/gateway/lb/handler/grpc/aggregation.go index 24ad0925cd..0a5a04a235 100644 --- a/pkg/gateway/lb/handler/grpc/aggregation.go +++ b/pkg/gateway/lb/handler/grpc/aggregation.go @@ -101,6 +101,7 @@ func (s *server) aggregationSearch( target + " canceled: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil case errors.Is(err, context.DeadlineExceeded), errors.Is(err, errors.ErrRPCCallFailed(target, context.DeadlineExceeded)): @@ -112,6 +113,7 @@ func (s *server) aggregationSearch( target + " deadline_exceeded: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil default: st, msg, err := status.ParseError(err, codes.Unknown, "failed to parse search gRPC error response", @@ -168,6 +170,7 @@ func (s *server) aggregationSearch( target + " canceled: " + err.Error())...) sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil case errors.Is(err, context.DeadlineExceeded), errors.Is(err, errors.ErrRPCCallFailed(target, context.DeadlineExceeded)): @@ -179,6 +182,7 @@ func (s *server) aggregationSearch( target + " deadline_exceeded: " + err.Error())...) 
sspan.SetStatus(trace.StatusError, err.Error()) } + log.Debug(err) return nil default: st, msg, err := status.ParseError(err, codes.Unknown, "failed to parse search gRPC error response", diff --git a/pkg/gateway/mirror/service/mirror.go b/pkg/gateway/mirror/service/mirror.go index dbd76572d4..959c555248 100644 --- a/pkg/gateway/mirror/service/mirror.go +++ b/pkg/gateway/mirror/service/mirror.go @@ -160,7 +160,7 @@ func (m *mirr) Start(ctx context.Context) <-chan error { // skipcq: GO-R1005 } } } - log.Debugf("[mirror]: connected mirror gateway targets: %v", m.gateway.GRPCClient().ConnectedAddrs()) + log.Debugf("[mirror]: connected mirror gateway targets: %v", m.gateway.GRPCClient().ConnectedAddrs(ctx)) } } })