diff --git a/pkg/app.go b/pkg/app.go index 5cf3789f..640570bc 100644 --- a/pkg/app.go +++ b/pkg/app.go @@ -324,8 +324,8 @@ func (app *App) Start() { app.running = false }() - sg := make(chan os.Signal) - signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM) + sg := make(chan os.Signal, 1) + signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) maxSessionCount := func() int64 { count := app.sessionPool.GetSessionCount() diff --git a/pkg/cluster/etcd_service_discovery.go b/pkg/cluster/etcd_service_discovery.go index 912eca6d..6812a400 100644 --- a/pkg/cluster/etcd_service_discovery.go +++ b/pkg/cluster/etcd_service_discovery.go @@ -27,12 +27,13 @@ import ( "strings" "sync" "time" + "github.com/topfreegames/pitaya/v3/pkg/config" "github.com/topfreegames/pitaya/v3/pkg/constants" "github.com/topfreegames/pitaya/v3/pkg/logger" "github.com/topfreegames/pitaya/v3/pkg/util" - clientv3 "go.etcd.io/etcd/client/v3" logutil "go.etcd.io/etcd/client/pkg/v3/logutil" + clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/namespace" "google.golang.org/grpc" ) @@ -81,14 +82,14 @@ func NewEtcdServiceDiscovery( client = cli[0] } sd := &etcdServiceDiscovery{ - running: false, - server: server, - serverMapByType: make(map[string]map[string]*Server), - listeners: make([]SDListener, 0), - stopChan: make(chan bool), - stopLeaseChan: make(chan bool), - appDieChan: appDieChan, - cli: client, + running: false, + server: server, + serverMapByType: make(map[string]map[string]*Server), + listeners: make([]SDListener, 0), + stopChan: make(chan bool), + stopLeaseChan: make(chan bool), + appDieChan: appDieChan, + cli: client, syncServersRunning: make(chan bool), } @@ -300,7 +301,7 @@ func (sd *etcdServiceDiscovery) GetServersByType(serverType string) (map[string] // Create a new map to avoid concurrent read and write access to the // map, this also prevents accidental changes to the list of servers // kept by the service discovery. - ret := make(map[string]*Server,len(sd.serverMapByType[serverType])) + ret := make(map[string]*Server, len(sd.serverMapByType[serverType])) for k, v := range sd.serverMapByType[serverType] { ret[k] = v } @@ -615,8 +616,12 @@ func (sd *etcdServiceDiscovery) revoke() error { go func() { defer close(c) logger.Log.Debug("waiting for etcd revoke") - _, err := sd.cli.Revoke(context.TODO(), sd.leaseID) - c <- err + if sd.cli != nil { + _, err := sd.cli.Revoke(context.TODO(), sd.leaseID) + c <- err + } else { + c <- nil + } logger.Log.Debug("finished waiting for etcd revoke") }() select { diff --git a/pkg/cluster/grpc_rpc_server_test.go b/pkg/cluster/grpc_rpc_server_test.go index 31ca32ec..2cde149b 100644 --- a/pkg/cluster/grpc_rpc_server_test.go +++ b/pkg/cluster/grpc_rpc_server_test.go @@ -32,6 +32,7 @@ func TestGRPCServerInit(t *testing.T) { sv := getServer() gs, err := NewGRPCServer(c, sv, []metrics.Reporter{}) + assert.NoError(t, err) gs.SetPitayaServer(mockPitayaServer) err = gs.Init() assert.NoError(t, err) diff --git a/pkg/cluster/nats_rpc_common.go b/pkg/cluster/nats_rpc_common.go index 9736be79..3d59de5c 100644 --- a/pkg/cluster/nats_rpc_common.go +++ b/pkg/cluster/nats_rpc_common.go @@ -22,6 +22,9 @@ package cluster import ( "fmt" + "os" + "syscall" + "time" nats "github.com/nats-io/nats.go" "github.com/topfreegames/pitaya/v3/pkg/logger" @@ -32,6 +35,8 @@ func getChannel(serverType, serverID string) string { } func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) { + connectedCh := make(chan bool) + initialConnectErrorCh := make(chan error) natsOptions := append( options, nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { @@ -49,7 +54,19 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf("nats connection closed. reason: %q", nc.LastError()) if appDieChan != nil { - appDieChan <- true + select { + case appDieChan <- true: + return + case initialConnectErrorCh <- nc.LastError(): + logger.Log.Warnf("appDieChan not ready, sending error in initialConnectCh") + default: + logger.Log.Warnf("no termination channel available, sending SIGTERM to app") + err := syscall.Kill(os.Getpid(), syscall.SIGTERM) + if err != nil { + logger.Log.Errorf("could not kill the application via SIGTERM, exiting", err) + os.Exit(1) + } + } } }), nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { @@ -61,11 +78,34 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf(err.Error()) } }), + nats.ConnectHandler(func(*nats.Conn) { + connectedCh <- true + }), ) nc, err := nats.Connect(connectString, natsOptions...) if err != nil { return nil, err } - return nc, nil + maxConnTimeout := nc.Opts.Timeout + if nc.Opts.RetryOnFailedConnect { + // This is non-deterministic becase jitter TLS is different and we need to simplify + // the calculations. What we want to do is simply not block forever the call while + // we don't set a timeout so low that hinders our own reconnect config: + // maxReconnectTimeout = reconnectWait + reconnectJitter + reconnectTimeout + // connectionTimeout + (maxReconnectionAttemps * maxReconnectTimeout) + // Thus, the time.After considers 2 times this value + maxReconnectionTimeout := nc.Opts.ReconnectWait + nc.Opts.ReconnectJitter + nc.Opts.Timeout + maxConnTimeout += time.Duration(nc.Opts.MaxReconnect) * maxReconnectionTimeout + } + + logger.Log.Debugf("attempting nats connection for a max of %v", maxConnTimeout) + select { + case <-connectedCh: + return nc, nil + case err := <-initialConnectErrorCh: + return nil, err + case <-time.After(maxConnTimeout * 2): + return nil, fmt.Errorf("timeout setting up nats connection") + } } diff --git a/pkg/cluster/nats_rpc_common_test.go b/pkg/cluster/nats_rpc_common_test.go index da7eeba2..5f3c2c1c 100644 --- a/pkg/cluster/nats_rpc_common_test.go +++ b/pkg/cluster/nats_rpc_common_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/test" nats "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/topfreegames/pitaya/v3/pkg/helpers" @@ -77,3 +78,136 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) { assert.True(t, ok) assert.True(t, value) } + +func TestSetupNatsConnReconnection(t *testing.T) { + t.Run("waits for reconnection on initial failure", func(t *testing.T) { + // Use an invalid address first to force initial connection failure + invalidAddr := "nats://invalid:4222" + validAddr := "nats://localhost:4222" + + urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + + go func() { + time.Sleep(50 * time.Millisecond) + ts := test.RunDefaultServer() + defer ts.Shutdown() + <-time.After(200 * time.Millisecond) + }() + + // Setup connection with retry enabled + appDieCh := make(chan bool) + conn, err := setupNatsConn( + urls, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + nats.MaxReconnects(5), + nats.RetryOnFailedConnect(true), + ) + + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.True(t, conn.IsConnected()) + + conn.Close() + }) + + t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn( + invalidAddr, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + nats.MaxReconnects(2), + nats.RetryOnFailedConnect(true), + ) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-appDieCh: + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fail() + } + }) + + t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn(invalidAddr, appDieCh) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fail() + } + }) + + t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) { + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + initialConnectionTimeout := time.Nanosecond + maxReconnectionAtetmpts := 1 + reconnectWait := time.Nanosecond + reconnectJitter := time.Nanosecond + maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout + maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout) + + maxTestTimeout := 100 * time.Millisecond + + // Assert that if it fails because of connection timeout the test will capture + assert.Greater(t, maxTestTimeout, maxReconnTimeout) + + ts := test.RunDefaultServer() + defer ts.Shutdown() + + go func() { + conn, err := setupNatsConn( + invalidAddr, + appDieCh, + nats.Timeout(initialConnectionTimeout), + nats.ReconnectWait(reconnectWait), + nats.MaxReconnects(maxReconnectionAtetmpts), + nats.ReconnectJitter(reconnectJitter, reconnectJitter), + nats.RetryOnFailedConnect(true), + ) + assert.Error(t, err) + assert.ErrorContains(t, err, "timeout setting up nats connection") + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(maxTestTimeout): + t.Fail() + } + }) +}