diff --git a/benchmark_test.go b/benchmark_test.go
index 56d0fe57..5a383761 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -33,10 +33,13 @@ func BenchmarkServer(b *testing.B) {
 	require.NoError(b, err, "Error in conf.SetDefaults")
 
 	b.Run("GetPeerRateLimit() with no batching", func(b *testing.B) {
-		client := guber.NewPeerClient(guber.PeerConfig{
+		client, err := guber.NewPeerClient(guber.PeerConfig{
 			Info:     cluster.GetRandomPeer(cluster.DataCenterNone),
 			Behavior: conf.Behaviors,
 		})
+		if err != nil {
+			b.Errorf("Error building client: %s", err)
+		}
 
 		b.ResetTimer()
 
diff --git a/functional_test.go b/functional_test.go
index 2d365b13..654342b7 100644
--- a/functional_test.go
+++ b/functional_test.go
@@ -36,7 +36,6 @@ import (
 	"github.com/prometheus/common/model"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials/insecure"
 	json "google.golang.org/protobuf/encoding/protojson"
 )
@@ -1618,6 +1617,16 @@ func TestHealthCheck(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), clock.Second*15)
 	defer cancel()
 	require.NoError(t, cluster.Restart(ctx))
+
+	// wait for every peer instance to come back online
+	for _, peer := range cluster.GetPeers() {
+		peerClient, err := guber.DialV1Server(peer.GRPCAddress, nil)
+		require.NoError(t, err)
+		testutil.UntilPass(t, 10, clock.Millisecond*300, func(t testutil.TestingT) {
+			healthResp, err = peerClient.HealthCheck(context.Background(), &guber.HealthCheckReq{})
+			assert.Equal(t, "healthy", healthResp.GetStatus())
+		})
+	}
 }
 
 func TestLeakyBucketDivBug(t *testing.T) {
@@ -1723,9 +1732,10 @@ func TestGRPCGateway(t *testing.T) {
 
 func TestGetPeerRateLimits(t *testing.T) {
 	ctx := context.Background()
-	peerClient := guber.NewPeerClient(guber.PeerConfig{
+	peerClient, err := guber.NewPeerClient(guber.PeerConfig{
 		Info: cluster.GetRandomPeer(cluster.DataCenterNone),
 	})
+	require.NoError(t, err)
 
 	t.Run("Stable rate check request order", func(t *testing.T) {
 		// Ensure response order matches rate check request order.
diff --git a/global.go b/global.go
index b1f652ae..bd0c1e7c 100644
--- a/global.go
+++ b/global.go
@@ -20,6 +20,7 @@ import (
 	"context"
 
 	"github.com/mailgun/holster/v4/syncutil"
+	"github.com/pkg/errors"
 	"github.com/prometheus/client_golang/prometheus"
 )
 
@@ -31,7 +32,7 @@ type globalManager struct {
 	wg        syncutil.WaitGroup
 	conf      BehaviorConfig
 	log       FieldLogger
-	instance  *V1Instance // todo circular import? V1Instance also holds a reference to globalManager
+	instance  *V1Instance // TODO circular import? V1Instance also holds a reference to globalManager
 	metricGlobalSendDuration prometheus.Summary
 	metricBroadcastDuration  prometheus.Summary
 	metricBroadcastCounter   *prometheus.CounterVec
@@ -249,8 +250,8 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]
 		cancel()
 
 		if err != nil {
-			// Skip peers that are not in a ready state
-			if !IsNotReady(err) {
+			// Only log if it's an unknown error
+			if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
 				gm.log.WithError(err).Errorf("while broadcasting global updates to '%s'", peer.Info().GRPCAddress)
 			}
 		}
@@ -260,6 +261,10 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]
 	fan.Wait()
 }
 
+// Close stops all goroutines and shuts down all the peers.
 func (gm *globalManager) Close() {
 	gm.wg.Stop()
+	for _, peer := range gm.instance.GetPeerList() {
+		_ = peer.Shutdown(context.Background())
+	}
 }
diff --git a/go.mod b/go.mod
index af975b8c..93080b32 100644
--- a/go.mod
+++ b/go.mod
@@ -24,6 +24,7 @@ require (
 	go.opentelemetry.io/otel/trace v1.21.0
 	go.uber.org/goleak v1.3.0
 	golang.org/x/net v0.18.0
+	golang.org/x/sync v0.3.0
 	golang.org/x/time v0.3.0
 	google.golang.org/genproto/googleapis/api v0.0.0-20231016165738-49dd2c1f3d0b
 	google.golang.org/grpc v1.59.0
diff --git a/go.sum b/go.sum
index 6ada5946..fea9ef4c 100644
--- a/go.sum
+++ b/go.sum
@@ -580,6 +580,7 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
diff --git a/gubernator.go b/gubernator.go
index f33fa48c..7ec9a96a 100644
--- a/gubernator.go
+++ b/gubernator.go
@@ -343,7 +343,7 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) {
 	// Make an RPC call to the peer that owns this rate limit
 	r, err := req.Peer.GetPeerRateLimit(ctx, req.Req)
 	if err != nil {
-		if IsNotReady(err) {
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			attempts++
 			metricBatchSendRetries.WithLabelValues(req.Req.Name).Inc()
 			req.Peer, err = s.GetPeer(ctx, req.Key)
@@ -528,7 +528,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
 	localPeers := s.conf.LocalPicker.Peers()
 	for _, peer := range localPeers {
 		for _, errMsg := range peer.GetLastErr() {
-			err := fmt.Errorf("Error returned from local peer.GetLastErr: %s", errMsg)
+			err := fmt.Errorf("error returned from local peer.GetLastErr: %s", errMsg)
 			span.RecordError(err)
 			errs = append(errs, err.Error())
 		}
@@ -538,7 +538,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
 	regionPeers := s.conf.RegionPicker.Peers()
 	for _, peer := range regionPeers {
 		for _, errMsg := range peer.GetLastErr() {
-			err := fmt.Errorf("Error returned from region peer.GetLastErr: %s", errMsg)
+			err := fmt.Errorf("error returned from region peer.GetLastErr: %s", errMsg)
 			span.RecordError(err)
 			errs = append(errs, err.Error())
 		}
@@ -586,7 +586,8 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_
 	return resp, nil
 }
 
-// SetPeers is called by the implementor to indicate the pool of peers has changed
+// SetPeers replaces the peers and shuts down all the previous peers.
+// TODO this should return an error if we failed to connect to any of the new peers
 func (s *V1Instance) SetPeers(peerInfo []PeerInfo) {
 	localPicker := s.conf.LocalPicker.New()
 	regionPicker := s.conf.RegionPicker.New()
@@ -597,13 +598,18 @@ func (s *V1Instance) SetPeers(peerInfo []PeerInfo) {
 			peer := s.conf.RegionPicker.GetByPeerInfo(info)
 			// If we don't have an existing PeerClient create a new one
 			if peer == nil {
-				peer = NewPeerClient(PeerConfig{
+				var err error
+				peer, err = NewPeerClient(PeerConfig{
 					TraceGRPC: s.conf.PeerTraceGRPC,
 					Behavior:  s.conf.Behaviors,
 					TLS:       s.conf.PeerTLS,
 					Log:       s.log,
 					Info:      info,
 				})
+				if err != nil {
+					s.log.Errorf("error connecting to peer %s: %s", info.GRPCAddress, err)
+					return
+				}
 			}
 			regionPicker.Add(peer)
 			continue
@@ -611,13 +617,18 @@ func (s *V1Instance) SetPeers(peerInfo []PeerInfo) {
 		// If we don't have an existing PeerClient create a new one
 		peer := s.conf.LocalPicker.GetByPeerInfo(info)
 		if peer == nil {
-			peer = NewPeerClient(PeerConfig{
+			var err error
+			peer, err = NewPeerClient(PeerConfig{
 				TraceGRPC: s.conf.PeerTraceGRPC,
 				Behavior:  s.conf.Behaviors,
 				TLS:       s.conf.PeerTLS,
 				Log:       s.log,
 				Info:      info,
 			})
+			if err != nil {
+				s.log.Errorf("error connecting to peer %s: %s", info.GRPCAddress, err)
+				return
+			}
 		}
 		localPicker.Add(peer)
 	}
diff --git a/peer_client.go b/peer_client.go
index a39d9f02..39c13c14 100644
--- a/peer_client.go
+++ b/peer_client.go
@@ -21,6 +21,7 @@ import (
 	"crypto/tls"
 	"fmt"
 	"sync"
+	"sync/atomic"
 
 	"github.com/mailgun/holster/v4/clock"
 	"github.com/mailgun/holster/v4/collections"
@@ -33,8 +34,10 @@ import (
 	"go.opentelemetry.io/otel/propagation"
 	"go.opentelemetry.io/otel/trace"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/grpc/status"
 )
 
 type PeerPicker interface {
@@ -45,24 +48,16 @@ type PeerPicker interface {
 	Add(*PeerClient)
 }
 
-type peerStatus int
-
-const (
-	peerNotConnected peerStatus = iota
-	peerConnected
-	peerClosing
-)
-
 type PeerClient struct {
-	client   PeersV1Client
-	conn     *grpc.ClientConn
-	conf     PeerConfig
-	queue    chan *request
-	lastErrs *collections.LRUCache
-
-	mutex  sync.RWMutex   // This mutex is for verifying the closing state of the client
-	status peerStatus     // Keep the current status of the peer
-	wg     sync.WaitGroup // This wait group is to monitor the number of in-flight requests
+	client      PeersV1Client
+	conn        *grpc.ClientConn
+	conf        PeerConfig
+	queue       chan *request
+	queueClosed atomic.Bool
+	lastErrs    *collections.LRUCache
+
+	wgMutex sync.RWMutex
+	wg      sync.WaitGroup // Monitor the number of in-flight requests. GUARDED_BY(wgMutex)
 }
 
 type response struct {
@@ -84,80 +79,39 @@ type PeerConfig struct {
 	TraceGRPC bool
 }
 
-func NewPeerClient(conf PeerConfig) *PeerClient {
-	return &PeerClient{
+// NewPeerClient tries to establish a connection to a peer in a non-blocking fashion.
+// If batching is enabled, it also starts a goroutine where batches will be processed.
+func NewPeerClient(conf PeerConfig) (*PeerClient, error) {
+	peerClient := &PeerClient{
 		queue:    make(chan *request, 1000),
-		status:   peerNotConnected,
 		conf:     conf,
 		lastErrs: collections.NewLRUCache(100),
 	}
-}
-
-// Connect establishes a GRPC connection to a peer
-func (c *PeerClient) connect(ctx context.Context) (err error) {
-	// NOTE: To future self, this mutex is used here because we need to know if the peer is disconnecting and
-	// handle ErrClosing. Since this mutex MUST be here we take this opportunity to also see if we are connected.
-	// Doing this here encapsulates managing the connected state to the PeerClient struct. Previously a PeerClient
-	// was connected when `NewPeerClient()` was called however, when adding support for multi data centers having a
-	// PeerClient connected to every Peer in every data center continuously is not desirable.
-
-	funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("PeerClient.connect"))
-	defer funcTimer.ObserveDuration()
-	lockTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("PeerClient.connect_RLock"))
-
-	c.mutex.RLock()
-	lockTimer.ObserveDuration()
-
-	if c.status == peerClosing {
-		c.mutex.RUnlock()
-		return &PeerErr{err: errors.New("already disconnecting")}
-	}
-
-	if c.status == peerNotConnected {
-		// This mutex stuff looks wonky, but it allows us to use RLock() 99% of the time, while the 1% where we
-		// actually need to connect uses a full Lock(), using RLock() most of which should reduce the over head
-		// of a full lock on every call
-
-		// Yield the read lock so we can get the RW lock
-		c.mutex.RUnlock()
-		c.mutex.Lock()
-		defer c.mutex.Unlock()
+	var opts []grpc.DialOption
 
-		// Now that we have the RW lock, ensure no else got here ahead of us.
-		if c.status == peerConnected {
-			return nil
-		}
-
-		// Setup OpenTelemetry interceptor to propagate spans.
-		var opts []grpc.DialOption
-
-		if c.conf.TraceGRPC {
-			opts = []grpc.DialOption{
-				grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
-			}
+	if conf.TraceGRPC {
+		opts = []grpc.DialOption{
+			grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
 		}
+	}
 
-		if c.conf.TLS != nil {
-			opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(c.conf.TLS)))
-		} else {
-			opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
-		}
+	if conf.TLS != nil {
+		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(conf.TLS)))
+	} else {
+		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	}
 
-		var err error
-		c.conn, err = grpc.Dial(c.conf.Info.GRPCAddress, opts...)
-		if err != nil {
-			return c.setLastErr(&PeerErr{err: errors.Wrapf(err, "failed to dial peer %s", c.conf.Info.GRPCAddress)})
-		}
-		c.client = NewPeersV1Client(c.conn)
-		c.status = peerConnected
+	var err error
+	peerClient.conn, err = grpc.Dial(conf.Info.GRPCAddress, opts...)
+	if err != nil {
+		return nil, err
+	}
+	peerClient.client = NewPeersV1Client(peerClient.conn)
 
-		if !c.conf.Behavior.DisableBatching {
-			go c.runBatch()
-		}
-		return nil
+	if !conf.Behavior.DisableBatching {
+		go peerClient.runBatch()
 	}
-	c.mutex.RUnlock()
-	return nil
+	return peerClient, nil
 }
 
 // Info returns PeerInfo struct that describes this PeerClient
@@ -207,21 +161,13 @@ func (c *PeerClient) GetPeerRateLimit(ctx context.Context, r *RateLimitReq) (res
 
 // GetPeerRateLimits requests a list of rate limit statuses from a peer
 func (c *PeerClient) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimitsReq) (resp *GetPeerRateLimitsResp, err error) {
-	if err := c.connect(ctx); err != nil {
-		err = errors.Wrap(err, "Error in connect")
-		metricCheckErrorCounter.WithLabelValues("Connect error").Add(1)
-		return nil, c.setLastErr(err)
-	}
-
-	// NOTE: This must be done within the RLock since calling Wait() in Shutdown() causes
+	// NOTE: This must be done within the Lock since calling Wait() in Shutdown() causes
 	// a race condition if called within a separate go routine if the internal wg is `0`
 	// when Wait() is called then Add(1) is called concurrently.
-	c.mutex.RLock()
+	c.wgMutex.Lock()
 	c.wg.Add(1)
-	defer func() {
-		c.mutex.RUnlock()
-		defer c.wg.Done()
-	}()
+	c.wgMutex.Unlock()
+	defer c.wg.Done()
 
 	resp, err = c.client.GetPeerRateLimits(ctx, r)
 	if err != nil {
@@ -241,17 +187,12 @@
 
 // UpdatePeerGlobals sends global rate limit status updates to a peer
 func (c *PeerClient) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (resp *UpdatePeerGlobalsResp, err error) {
-	if err := c.connect(ctx); err != nil {
-		return nil, c.setLastErr(err)
-	}
-
 	// See NOTE above about RLock and wg.Add(1)
-	c.mutex.RLock()
+	c.wgMutex.Lock()
 	c.wg.Add(1)
-	defer func() {
-		c.mutex.RUnlock()
-		defer c.wg.Done()
-	}()
+	c.wgMutex.Unlock()
+	defer c.wg.Done()
 
 	resp, err = c.client.UpdatePeerGlobals(ctx, r)
 	if err != nil {
@@ -296,29 +237,26 @@ func (c *PeerClient) getPeerRateLimitsBatch(ctx context.Context, r *RateLimitReq
 	funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("PeerClient.getPeerRateLimitsBatch"))
 	defer funcTimer.ObserveDuration()
 
-	if err := c.connect(ctx); err != nil {
-		err = errors.Wrap(err, "Error in connect")
-		return nil, c.setLastErr(err)
-	}
-
-	// See NOTE above about RLock and wg.Add(1)
-	c.mutex.RLock()
-	if c.status == peerClosing {
-		err := &PeerErr{err: errors.New("already disconnecting")}
-		return nil, c.setLastErr(err)
-	}
-
-	// Wait for a response or context cancel
 	req := request{
 		resp:    make(chan *response, 1),
 		ctx:     ctx,
 		request: r,
 	}
 
+	c.wgMutex.Lock()
+	c.wg.Add(1)
+	c.wgMutex.Unlock()
+	defer c.wg.Done()
+
 	// Enqueue the request to be sent
 	peerAddr := c.Info().GRPCAddress
 	metricBatchQueueLength.WithLabelValues(peerAddr).Set(float64(len(c.queue)))
 
+	if c.queueClosed.Load() {
+		// this check prevents "panic: send on closed channel"
+		return nil, status.Error(codes.Canceled, "grpc: the client connection is closing")
+	}
+
 	select {
 	case c.queue <- &req:
 		// Successfully enqueued request.
@@ -326,12 +264,7 @@
 		return nil, errors.Wrap(ctx.Err(), "Context error while enqueuing request")
 	}
 
-	c.wg.Add(1)
-	defer func() {
-		c.mutex.RUnlock()
-		c.wg.Done()
-	}()
-
+	// Wait for a response or context cancel
 	select {
 	case re := <-req.resp:
 		if re.err != nil {
@@ -344,7 +277,7 @@
 	}
 }
 
-// run processes batching requests by waiting for requests to be queued. Send
+// runBatch processes batching requests by waiting for requests to be queued. Send
 // the queue as a batch when either c.batchWait time has elapsed or the queue
 // reaches c.batchLimit.
 func (c *PeerClient) runBatch() {
@@ -358,8 +291,8 @@
 		select {
 		case r, ok := <-c.queue:
-			// If the queue has shutdown, we need to send the rest of the queue
 			if !ok {
+				// If the queue has shutdown, we need to send the rest of the queue
 				if len(queue) > 0 {
 					c.sendBatch(ctx, queue)
 				}
 				return
@@ -426,7 +359,6 @@
 		prop.Inject(r.ctx, &MetadataCarrier{Map: r.request.Metadata})
 		req.Requests = append(req.Requests, r.request)
 		tracing.EndScope(r.ctx, nil)
-
 	}
 
 	timeoutCtx, timeoutCancel := context.WithTimeout(ctx, c.conf.Behavior.BatchTimeout)
@@ -470,31 +402,26 @@
 		}
 	}
 }
 
-// Shutdown will gracefully shutdown the client connection, until the context is cancelled
+// Shutdown waits until all outstanding requests have finished or the context is cancelled.
+// Then it closes the grpc connection.
 func (c *PeerClient) Shutdown(ctx context.Context) error {
-	// Take the write lock since we're going to modify the closing state
-	c.mutex.Lock()
-	if c.status == peerClosing || c.status == peerNotConnected {
-		c.mutex.Unlock()
-		return nil
-	}
-	defer c.mutex.Unlock()
-
-	c.status = peerClosing
+	// ensure we don't leak goroutines, even if the Shutdown times out
+	defer c.conn.Close()
 
-	defer func() {
-		if c.conn != nil {
-			c.conn.Close()
-		}
-	}()
-
-	// This allows us to wait on the waitgroup, or until the context
-	// has been cancelled. This doesn't leak goroutines, because
-	// closing the connection will kill any outstanding requests.
 	waitChan := make(chan struct{})
 	go func() {
+		// drain in-flight requests
+		c.wgMutex.Lock()
+		defer c.wgMutex.Unlock()
 		c.wg.Wait()
+
+		// clear errors
+		c.lastErrs = collections.NewLRUCache(100)
+
+		// signal that no more items will be sent
+		c.queueClosed.Store(true)
 		close(c.queue)
+
 		close(waitChan)
 	}()
@@ -505,30 +432,3 @@
 		return nil
 	}
 }
-
-// PeerErr is returned if the peer is not connected or is in a closing state
-type PeerErr struct {
-	err error
-}
-
-func (p *PeerErr) NotReady() bool {
-	return true
-}
-
-func (p *PeerErr) Error() string {
-	return p.err.Error()
-}
-
-func (p *PeerErr) Cause() error {
-	return p.err
-}
-
-type notReadyErr interface {
-	NotReady() bool
-}
-
-// IsNotReady returns true if the err is because the peer is not connected or in a closing state
-func IsNotReady(err error) bool {
-	te, ok := err.(notReadyErr)
-	return ok && te.NotReady()
-}
diff --git a/peer_client_test.go b/peer_client_test.go
index 99924bed..d739f40a 100644
--- a/peer_client_test.go
+++ b/peer_client_test.go
@@ -19,13 +19,15 @@ package gubernator_test
 import (
 	"context"
 	"runtime"
-	"sync"
+	"strings"
 	"testing"
 
 	gubernator "github.com/mailgun/gubernator/v2"
 	"github.com/mailgun/gubernator/v2/cluster"
 	"github.com/mailgun/holster/v4/clock"
-	"github.com/stretchr/testify/assert"
+	"github.com/pkg/errors"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/errgroup"
 )
 
 func TestPeerClientShutdown(t *testing.T) {
@@ -56,17 +58,17 @@
 		c := cases[i]
 		t.Run(c.Name, func(t *testing.T) {
-			client := gubernator.NewPeerClient(gubernator.PeerConfig{
+			client, err := gubernator.NewPeerClient(gubernator.PeerConfig{
 				Info:     cluster.GetRandomPeer(cluster.DataCenterNone),
 				Behavior: config,
 			})
+			require.NoError(t, err)
 
-			wg := sync.WaitGroup{}
-			wg.Add(threads)
+			wg := errgroup.Group{}
+			wg.SetLimit(threads)
 
 			// Spawn a whole bunch of concurrent requests to test shutdown in various states
 			for j := 0; j < threads; j++ {
-				go func() {
-					defer wg.Done()
+				wg.Go(func() error {
 					ctx := context.Background()
 					_, err := client.GetPeerRateLimit(ctx, &gubernator.RateLimitReq{
 						Hits:     1,
@@ -74,28 +76,26 @@
 						Behavior: c.Behavior,
 					})
 
-					isExpectedErr := false
-
-					switch err.(type) {
-					case *gubernator.PeerErr:
-						isExpectedErr = true
-					case nil:
-						isExpectedErr = true
+					if err != nil {
+						if !strings.Contains(err.Error(), "client connection is closing") {
+							return errors.Wrap(err, "unexpected error in test")
+						}
 					}
-
-					assert.True(t, true, isExpectedErr)
-
-				}()
+					return nil
+				})
 			}
 
 			// yield the processor that way we allow other goroutines to start their request
 			runtime.Gosched()
 
-			err := client.Shutdown(context.Background())
-			assert.NoError(t, err)
+			shutDownErr := client.Shutdown(context.Background())
 
-			wg.Wait()
+			err = wg.Wait()
+			if err != nil {
+				t.Error(err)
+				t.Fail()
+			}
+			require.NoError(t, shutDownErr)
 		})
-
 	}
 }
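
Illustrative usage sketch (not part of the patch above): with this change, NewPeerClient returns an error instead of connecting lazily, and Shutdown drains in-flight requests before closing the gRPC connection. The snippet below shows how a caller might use the new contract; the peer address, limit values, and the batching settings are assumptions made for the example, not values taken from the repository.

package main

import (
	"context"
	"log"

	guber "github.com/mailgun/gubernator/v2"
)

func main() {
	// NewPeerClient now returns an error; per the new doc comment the dial itself is non-blocking.
	client, err := guber.NewPeerClient(guber.PeerConfig{
		Info:     guber.PeerInfo{GRPCAddress: "127.0.0.1:9990"},  // hypothetical peer address
		Behavior: guber.BehaviorConfig{DisableBatching: true},    // keep the sketch simple: no batch goroutine
	})
	if err != nil {
		log.Fatalf("error building peer client: %s", err)
	}
	// Shutdown waits for in-flight requests (or the context) and then closes the connection.
	defer func() { _ = client.Shutdown(context.Background()) }()

	resp, err := client.GetPeerRateLimit(context.Background(), &guber.RateLimitReq{
		Name:      "example_limit",   // illustrative rate limit
		UniqueKey: "account:1234",
		Hits:      1,
		Limit:     10,
		Duration:  guber.Minute,
		Behavior:  guber.Behavior_NO_BATCHING, // bypass the batch queue for this single request
	})
	if err != nil {
		log.Fatalf("rate limit check failed: %s", err)
	}
	log.Printf("remaining: %d", resp.Remaining)
}

Because Shutdown now also closes the request queue, callers that race new requests against Shutdown should expect the codes.Canceled "client connection is closing" error introduced in the diff rather than the removed PeerErr type.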