From b5f6b86f604901e5304b19d169d3c0dc409d1604 Mon Sep 17 00:00:00 2001 From: Pedro Soares Date: Thu, 23 Jan 2025 09:50:28 -0300 Subject: [PATCH 1/3] fix(nats): wait for reconnects on setup If the initial connect fails, nats will spawn reconnect async handlers. Thus, we need to wait for all reconnects to be attempted before returning to the caller, otherwise we won't be making use of reconnections --- pkg/cluster/grpc_rpc_server_test.go | 1 + pkg/cluster/nats_rpc_common.go | 44 ++++++++- pkg/cluster/nats_rpc_common_test.go | 134 ++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 2 deletions(-) diff --git a/pkg/cluster/grpc_rpc_server_test.go b/pkg/cluster/grpc_rpc_server_test.go index 31ca32ec..2cde149b 100644 --- a/pkg/cluster/grpc_rpc_server_test.go +++ b/pkg/cluster/grpc_rpc_server_test.go @@ -32,6 +32,7 @@ func TestGRPCServerInit(t *testing.T) { sv := getServer() gs, err := NewGRPCServer(c, sv, []metrics.Reporter{}) + assert.NoError(t, err) gs.SetPitayaServer(mockPitayaServer) err = gs.Init() assert.NoError(t, err) diff --git a/pkg/cluster/nats_rpc_common.go b/pkg/cluster/nats_rpc_common.go index 9736be79..3d59de5c 100644 --- a/pkg/cluster/nats_rpc_common.go +++ b/pkg/cluster/nats_rpc_common.go @@ -22,6 +22,9 @@ package cluster import ( "fmt" + "os" + "syscall" + "time" nats "github.com/nats-io/nats.go" "github.com/topfreegames/pitaya/v3/pkg/logger" @@ -32,6 +35,8 @@ func getChannel(serverType, serverID string) string { } func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) { + connectedCh := make(chan bool) + initialConnectErrorCh := make(chan error) natsOptions := append( options, nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { @@ -49,7 +54,19 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf("nats connection closed. 
reason: %q", nc.LastError()) if appDieChan != nil { - appDieChan <- true + select { + case appDieChan <- true: + return + case initialConnectErrorCh <- nc.LastError(): + logger.Log.Warnf("appDieChan not ready, sending error in initialConnectCh") + default: + logger.Log.Warnf("no termination channel available, sending SIGTERM to app") + err := syscall.Kill(os.Getpid(), syscall.SIGTERM) + if err != nil { + logger.Log.Errorf("could not kill the application via SIGTERM, exiting", err) + os.Exit(1) + } + } } }), nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { @@ -61,11 +78,34 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf(err.Error()) } }), + nats.ConnectHandler(func(*nats.Conn) { + connectedCh <- true + }), ) nc, err := nats.Connect(connectString, natsOptions...) if err != nil { return nil, err } - return nc, nil + maxConnTimeout := nc.Opts.Timeout + if nc.Opts.RetryOnFailedConnect { + // This is non-deterministic because jitter TLS is different and we need to simplify + // the calculations. 
What we want is to avoid blocking the call forever while + // not setting a timeout so low that it hinders our own reconnect config: + // maxReconnectTimeout = reconnectWait + reconnectJitter + reconnectTimeout + // connectionTimeout + (maxReconnectionAttempts * maxReconnectTimeout) + // Thus, the time.After considers 2 times this value + maxReconnectionTimeout := nc.Opts.ReconnectWait + nc.Opts.ReconnectJitter + nc.Opts.Timeout + maxConnTimeout += time.Duration(nc.Opts.MaxReconnect) * maxReconnectionTimeout + } + + logger.Log.Debugf("attempting nats connection for a max of %v", maxConnTimeout) + select { + case <-connectedCh: + return nc, nil + case err := <-initialConnectErrorCh: + return nil, err + case <-time.After(maxConnTimeout * 2): + return nil, fmt.Errorf("timeout setting up nats connection") + } } diff --git a/pkg/cluster/nats_rpc_common_test.go b/pkg/cluster/nats_rpc_common_test.go index da7eeba2..5f3c2c1c 100644 --- a/pkg/cluster/nats_rpc_common_test.go +++ b/pkg/cluster/nats_rpc_common_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/test" nats "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/topfreegames/pitaya/v3/pkg/helpers" @@ -77,3 +78,136 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) { assert.True(t, ok) assert.True(t, value) } + +func TestSetupNatsConnReconnection(t *testing.T) { + t.Run("waits for reconnection on initial failure", func(t *testing.T) { + // Use an invalid address first to force initial connection failure + invalidAddr := "nats://invalid:4222" + validAddr := "nats://localhost:4222" + + urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + + go func() { + time.Sleep(50 * time.Millisecond) + ts := test.RunDefaultServer() + defer ts.Shutdown() + <-time.After(200 * time.Millisecond) + }() + + // Setup connection with retry enabled + appDieCh := make(chan bool) + conn, err := setupNatsConn( + urls, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + 
nats.MaxReconnects(5), + nats.RetryOnFailedConnect(true), + ) + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.True(t, conn.IsConnected()) + + conn.Close() + }) + + t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn( + invalidAddr, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + nats.MaxReconnects(2), + nats.RetryOnFailedConnect(true), + ) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-appDieCh: + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fail() + } + }) + + t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn(invalidAddr, appDieCh) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fail() + } + }) + + t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + initialConnectionTimeout := time.Nanosecond + maxReconnectionAtetmpts := 1 + reconnectWait := time.Nanosecond + reconnectJitter := time.Nanosecond + maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout + maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout) + + maxTestTimeout := 100 * time.Millisecond + + // Assert that if it fails because of connection timeout the test will capture + 
assert.Greater(t, maxTestTimeout, maxReconnTimeout) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn( + invalidAddr, + appDieCh, + nats.Timeout(initialConnectionTimeout), + nats.ReconnectWait(reconnectWait), + nats.MaxReconnects(maxReconnectionAtetmpts), + nats.ReconnectJitter(reconnectJitter, reconnectJitter), + nats.RetryOnFailedConnect(true), + ) + assert.Error(t, err) + assert.ErrorContains(t, err, "timeout setting up nats connection") + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(maxTestTimeout): + t.Fail() + } + }) +} From 6265c4f78d1625b611a727b66e507960ee0e2ed3 Mon Sep 17 00:00:00 2001 From: Pedro Soares Date: Mon, 27 Jan 2025 11:02:24 -0300 Subject: [PATCH 2/3] fix(app): init sig chan as buffered --- pkg/app.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/app.go b/pkg/app.go index 5cf3789f..640570bc 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -324,8 +324,8 @@ func (app *App) Start() { app.running = false }() - sg := make(chan os.Signal) - signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM) + sg := make(chan os.Signal, 1) + signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) maxSessionCount := func() int64 { count := app.sessionPool.GetSessionCount() From b1e951530d06c5ffa44f056f6a1abc9cf8a459fa Mon Sep 17 00:00:00 2001 From: Pedro Soares Date: Mon, 27 Jan 2025 11:03:13 -0300 Subject: [PATCH 3/3] fix(etcd): prevent shutdown from crashing app If etcd module is shutdown before all connections are set up, it will crash trying to access sd.cli where it's still nil. 
Thus adding a check on shutdown --- pkg/cluster/etcd_service_discovery.go | 29 ++++++++++++++++----------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/pkg/cluster/etcd_service_discovery.go b/pkg/cluster/etcd_service_discovery.go index 912eca6d..6812a400 100644 --- a/pkg/cluster/etcd_service_discovery.go +++ b/pkg/cluster/etcd_service_discovery.go @@ -27,12 +27,13 @@ import ( "strings" "sync" "time" + "github.com/topfreegames/pitaya/v3/pkg/config" "github.com/topfreegames/pitaya/v3/pkg/constants" "github.com/topfreegames/pitaya/v3/pkg/logger" "github.com/topfreegames/pitaya/v3/pkg/util" - clientv3 "go.etcd.io/etcd/client/v3" logutil "go.etcd.io/etcd/client/pkg/v3/logutil" + clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/namespace" "google.golang.org/grpc" ) @@ -81,14 +82,14 @@ func NewEtcdServiceDiscovery( client = cli[0] } sd := &etcdServiceDiscovery{ - running: false, - server: server, - serverMapByType: make(map[string]map[string]*Server), - listeners: make([]SDListener, 0), - stopChan: make(chan bool), - stopLeaseChan: make(chan bool), - appDieChan: appDieChan, - cli: client, + running: false, + server: server, + serverMapByType: make(map[string]map[string]*Server), + listeners: make([]SDListener, 0), + stopChan: make(chan bool), + stopLeaseChan: make(chan bool), + appDieChan: appDieChan, + cli: client, syncServersRunning: make(chan bool), } @@ -300,7 +301,7 @@ func (sd *etcdServiceDiscovery) GetServersByType(serverType string) (map[string] // Create a new map to avoid concurrent read and write access to the // map, this also prevents accidental changes to the list of servers // kept by the service discovery. 
- ret := make(map[string]*Server,len(sd.serverMapByType[serverType])) + ret := make(map[string]*Server, len(sd.serverMapByType[serverType])) for k, v := range sd.serverMapByType[serverType] { ret[k] = v } @@ -615,8 +616,12 @@ func (sd *etcdServiceDiscovery) revoke() error { go func() { defer close(c) logger.Log.Debug("waiting for etcd revoke") - _, err := sd.cli.Revoke(context.TODO(), sd.leaseID) - c <- err + if sd.cli != nil { + _, err := sd.cli.Revoke(context.TODO(), sd.leaseID) + c <- err + } else { + c <- nil + } logger.Log.Debug("finished waiting for etcd revoke") }() select {