Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(nats): wait for reconnects on setup #440

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ func (app *App) Start() {
app.running = false
}()

sg := make(chan os.Signal)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM)
sg := make(chan os.Signal, 1)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)

maxSessionCount := func() int64 {
count := app.sessionPool.GetSessionCount()
Expand Down
29 changes: 17 additions & 12 deletions pkg/cluster/etcd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (
"strings"
"sync"
"time"

"github.com/topfreegames/pitaya/v3/pkg/config"
"github.com/topfreegames/pitaya/v3/pkg/constants"
"github.com/topfreegames/pitaya/v3/pkg/logger"
"github.com/topfreegames/pitaya/v3/pkg/util"
clientv3 "go.etcd.io/etcd/client/v3"
logutil "go.etcd.io/etcd/client/pkg/v3/logutil"
clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/namespace"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -81,14 +82,14 @@ func NewEtcdServiceDiscovery(
client = cli[0]
}
sd := &etcdServiceDiscovery{
running: false,
server: server,
serverMapByType: make(map[string]map[string]*Server),
listeners: make([]SDListener, 0),
stopChan: make(chan bool),
stopLeaseChan: make(chan bool),
appDieChan: appDieChan,
cli: client,
running: false,
server: server,
serverMapByType: make(map[string]map[string]*Server),
listeners: make([]SDListener, 0),
stopChan: make(chan bool),
stopLeaseChan: make(chan bool),
appDieChan: appDieChan,
cli: client,
syncServersRunning: make(chan bool),
}

Expand Down Expand Up @@ -300,7 +301,7 @@ func (sd *etcdServiceDiscovery) GetServersByType(serverType string) (map[string]
// Create a new map to avoid concurrent read and write access to the
// map, this also prevents accidental changes to the list of servers
// kept by the service discovery.
ret := make(map[string]*Server,len(sd.serverMapByType[serverType]))
ret := make(map[string]*Server, len(sd.serverMapByType[serverType]))
for k, v := range sd.serverMapByType[serverType] {
ret[k] = v
}
Expand Down Expand Up @@ -615,8 +616,12 @@ func (sd *etcdServiceDiscovery) revoke() error {
go func() {
defer close(c)
logger.Log.Debug("waiting for etcd revoke")
_, err := sd.cli.Revoke(context.TODO(), sd.leaseID)
c <- err
if sd.cli != nil {
_, err := sd.cli.Revoke(context.TODO(), sd.leaseID)
c <- err
} else {
c <- nil
}
logger.Log.Debug("finished waiting for etcd revoke")
}()
select {
Expand Down
1 change: 1 addition & 0 deletions pkg/cluster/grpc_rpc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestGRPCServerInit(t *testing.T) {

sv := getServer()
gs, err := NewGRPCServer(c, sv, []metrics.Reporter{})
assert.NoError(t, err)
gs.SetPitayaServer(mockPitayaServer)
err = gs.Init()
assert.NoError(t, err)
Expand Down
44 changes: 42 additions & 2 deletions pkg/cluster/nats_rpc_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ package cluster

import (
"fmt"
"os"
"syscall"
"time"

nats "github.com/nats-io/nats.go"
"github.com/topfreegames/pitaya/v3/pkg/logger"
Expand All @@ -32,6 +35,8 @@ func getChannel(serverType, serverID string) string {
}

func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
connectedCh := make(chan bool)
initialConnectErrorCh := make(chan error)
natsOptions := append(
options,
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
Expand All @@ -49,7 +54,19 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O

logger.Log.Errorf("nats connection closed. reason: %q", nc.LastError())
if appDieChan != nil {
appDieChan <- true
select {
case appDieChan <- true:
return
case initialConnectErrorCh <- nc.LastError():
logger.Log.Warnf("appDieChan not ready, sending error in initialConnectCh")
default:
logger.Log.Warnf("no termination channel available, sending SIGTERM to app")
err := syscall.Kill(os.Getpid(), syscall.SIGTERM)
if err != nil {
logger.Log.Errorf("could not kill the application via SIGTERM, exiting", err)
os.Exit(1)
}
}
}
}),
nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) {
Expand All @@ -61,11 +78,34 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
logger.Log.Errorf(err.Error())
}
}),
nats.ConnectHandler(func(*nats.Conn) {
connectedCh <- true
}),
)

nc, err := nats.Connect(connectString, natsOptions...)
if err != nil {
return nil, err
}
return nc, nil
maxConnTimeout := nc.Opts.Timeout
if nc.Opts.RetryOnFailedConnect {
// This is non-deterministic because the jitter for TLS connections is different and we need to simplify
// the calculations. What we want is simply to not block the call forever, while
// not setting a timeout so low that it hinders our own reconnect config:
// maxReconnectTimeout = reconnectWait + reconnectJitter + reconnectTimeout
// connectionTimeout + (maxReconnectionAttempts * maxReconnectTimeout)
// Thus, the time.After considers 2 times this value
maxReconnectionTimeout := nc.Opts.ReconnectWait + nc.Opts.ReconnectJitter + nc.Opts.Timeout
maxConnTimeout += time.Duration(nc.Opts.MaxReconnect) * maxReconnectionTimeout
}

logger.Log.Debugf("attempting nats connection for a max of %v", maxConnTimeout)
select {
case <-connectedCh:
return nc, nil
case err := <-initialConnectErrorCh:
return nil, err
case <-time.After(maxConnTimeout * 2):
return nil, fmt.Errorf("timeout setting up nats connection")
}
}
134 changes: 134 additions & 0 deletions pkg/cluster/nats_rpc_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"testing"
"time"

"github.com/nats-io/nats-server/v2/test"
nats "github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"github.com/topfreegames/pitaya/v3/pkg/helpers"
Expand Down Expand Up @@ -77,3 +78,136 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) {
assert.True(t, ok)
assert.True(t, value)
}

// TestSetupNatsConnReconnection exercises setupNatsConn while it waits for
// the initial NATS connection. The subtests cover: a connection that only
// succeeds after the server comes up late, reconnect attempts that are all
// exhausted, a failed connect without any retry options, and the setup
// timeout being hit.
func TestSetupNatsConnReconnection(t *testing.T) {
	t.Run("waits for reconnection on initial failure", func(t *testing.T) {
		// Use an invalid address first to force initial connection failure
		invalidAddr := "nats://invalid:4222"
		validAddr := "nats://localhost:4222"

		// Both URLs are handed over as a single comma-separated connect string.
		urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr)

		// Start the test server only after a 50ms delay (presumably so the
		// first connect attempts fail), keep it up for 200ms, then shut it
		// down via the deferred Shutdown.
		go func() {
			time.Sleep(50 * time.Millisecond)
			ts := test.RunDefaultServer()
			defer ts.Shutdown()
			<-time.After(200 * time.Millisecond)
		}()

		// Setup connection with retry enabled
		appDieCh := make(chan bool)
		conn, err := setupNatsConn(
			urls,
			appDieCh,
			nats.ReconnectWait(10*time.Millisecond),
			nats.MaxReconnects(5),
			nats.RetryOnFailedConnect(true),
		)

		// The call must return a live, connected connection and no error.
		assert.NoError(t, err)
		assert.NotNil(t, conn)
		assert.True(t, conn.IsConnected())

		conn.Close()
	})

	t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) {
		invalidAddr := "nats://invalid:4222"

		appDieCh := make(chan bool)
		// done is closed by the goroutine below once setupNatsConn has returned.
		done := make(chan any)

		// A server is started, but setupNatsConn is only given the invalid address.
		ts := test.RunDefaultServer()
		defer ts.Shutdown()

		// Run setupNatsConn in the background so the select below can bound
		// how long the test waits for it.
		go func() {
			conn, err := setupNatsConn(
				invalidAddr,
				appDieCh,
				nats.ReconnectWait(10*time.Millisecond),
				nats.MaxReconnects(2),
				nats.RetryOnFailedConnect(true),
			)
			assert.Error(t, err)
			assert.Nil(t, conn)
			close(done)
			// NOTE(review): appDieCh is closed once setup returns; confirm
			// nothing can still send on it afterwards (a send on a closed
			// channel panics).
			close(appDieCh)
		}()

		// Pass if the app-die channel fires or the goroutine finishes; fail
		// if neither happens within 250ms.
		select {
		case <-appDieCh:
		case <-done:
		case <-time.After(250 * time.Millisecond):
			t.Fail()
		}
	})

	t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) {
		invalidAddr := "nats://invalid:4222"

		appDieCh := make(chan bool)
		done := make(chan any)

		ts := test.RunDefaultServer()
		defer ts.Shutdown()

		// No nats options are passed here, and nothing in this subtest
		// receives from appDieCh while setupNatsConn runs.
		go func() {
			conn, err := setupNatsConn(invalidAddr, appDieCh)
			assert.Error(t, err)
			assert.Nil(t, conn)
			close(done)
			close(appDieCh)
		}()

		// setupNatsConn must return within 50ms, otherwise the subtest fails.
		select {
		case <-done:
		case <-time.After(50 * time.Millisecond):
			t.Fail()
		}
	})

	t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) {
		invalidAddr := "nats://invalid:4222"

		appDieCh := make(chan bool)
		done := make(chan any)

		// Use the smallest possible durations so the computed connection
		// timeout is tiny compared with the test deadline below.
		initialConnectionTimeout := time.Nanosecond
		maxReconnectionAtetmpts := 1
		reconnectWait := time.Nanosecond
		reconnectJitter := time.Nanosecond
		// Expected upper bound: the initial timeout plus, for each reconnect
		// attempt, wait + jitter + timeout.
		maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout
		maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout)

		maxTestTimeout := 100 * time.Millisecond

		// Sanity check: the test deadline must exceed the computed connection
		// timeout, otherwise a timeout inside setupNatsConn could not be
		// observed by this test.
		assert.Greater(t, maxTestTimeout, maxReconnTimeout)

		ts := test.RunDefaultServer()
		defer ts.Shutdown()

		go func() {
			conn, err := setupNatsConn(
				invalidAddr,
				appDieCh,
				nats.Timeout(initialConnectionTimeout),
				nats.ReconnectWait(reconnectWait),
				nats.MaxReconnects(maxReconnectionAtetmpts),
				nats.ReconnectJitter(reconnectJitter, reconnectJitter),
				nats.RetryOnFailedConnect(true),
			)
			// The error must be the setup timeout specifically, not just any failure.
			assert.Error(t, err)
			assert.ErrorContains(t, err, "timeout setting up nats connection")
			assert.Nil(t, conn)
			close(done)
			close(appDieCh)
		}()

		// Fail if setupNatsConn has not returned by the test deadline.
		select {
		case <-done:
		case <-time.After(maxTestTimeout):
			t.Fail()
		}
	})
}
Loading