From c7476034498204ed01cb072c146fd8846ec5d3e0 Mon Sep 17 00:00:00 2001 From: Bart Smykla Date: Mon, 14 Jan 2019 15:43:20 +0000 Subject: [PATCH] Added reconnection logic when NATS is disconnected How was it tested: I deployed this version of nats-queue-worker and then simulated NATS disconnection Signed-off-by: Bart Smykla --- handler/handler.go | 58 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/handler/handler.go b/handler/handler.go index 5eef9d5..4f6c1f4 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -4,20 +4,67 @@ import ( "encoding/json" "fmt" "log" + "sync" + "time" "github.com/nats-io/go-nats-streaming" "github.com/openfaas/faas/gateway/queue" ) +var ( + maxReconnect = 5 + delayBetweenReconnect = 2 +) + // NatsQueue queue for work type NatsQueue struct { nc stan.Conn + ncMutex *sync.RWMutex ClientID string ClusterID string NATSURL string Topic string } +func (q *NatsQueue) connect() error { + nc, err := stan.Connect( + q.ClusterID, + q.ClientID, + stan.NatsURL(q.NATSURL), + stan.SetConnectionLostHandler(func(conn stan.Conn, err error) { + log.Printf("Disconnected from %s\n", q.NATSURL) + + q.reconnect(0) + }), + ) + + if err != nil { + return err + } + + q.ncMutex.Lock() + q.nc = nc + q.ncMutex.Unlock() + + return nil +} + +func (q *NatsQueue) reconnect(iteration int) { + log.Printf("Trying to reconnect (%d) to %s\n", iteration, q.NATSURL) + + if iteration < maxReconnect { + time.Sleep(time.Second * time.Duration(iteration*delayBetweenReconnect)) + + if err := q.connect(); err != nil { + log.Printf("Reconnection (%d) to %s failed", iteration, q.NATSURL) + + q.reconnect(iteration + 1) + } else { + log.Printf("Reconnection (%d) to %s succed", iteration, q.NATSURL) + } + } +} + // CreateNatsQueue ready for asynchronous processing func CreateNatsQueue(address string, port int, clientConfig NatsConfig) (*NatsQueue, error) { var err error @@ -27,15 +74,16 @@ func CreateNatsQueue(address string, port int, clientConfig NatsConfig) (*NatsQu clientID := clientConfig.GetClientID() clusterID := "faas-cluster" - nc, err := stan.Connect(clusterID, clientID, stan.NatsURL(natsURL)) queue1 := NatsQueue{ - nc: nc, ClientID: clientID, ClusterID: clusterID, NATSURL: natsURL, Topic: "faas-request", + ncMutex: &sync.RWMutex{}, } + err = queue1.connect() + return &queue1, err } @@ -50,7 +98,11 @@ func (q *NatsQueue) Queue(req *queue.Request) error { log.Println(err) } - err = q.nc.Publish(q.Topic, out) + q.ncMutex.RLock() + nc := q.nc + q.ncMutex.RUnlock() + + err = nc.Publish(q.Topic, out) return err }